## <span style="color:#DFFF00">0. Model example and preprocessing</span>

In [1]:
## Model.py
## Example model
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """MNIST classifier: two 5x5 conv stages (32 and 64 channels), each followed
    by 2x2 max-pooling and ReLU, then two fully connected layers.

    fc1 expects 64 * 4 * 4 = 1024 features, which corresponds to 1x28x28
    inputs (MNIST). ``forward`` returns log-probabilities of shape
    ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)   # 28x28 -> 24x24, pooled to 12x12
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)  # 12x12 -> 8x8, pooled to 4x4
        # NOTE(review): defined but never applied in forward(). Dropout2d has no
        # parameters, so it does not affect state_dict(); kept so the module
        # layout is unchanged.
        self.conv2_drop = nn.Dropout2d()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(1024, 2048)  # 64 channels * 4 * 4 = 1024 features
        self.fc2 = nn.Linear(2048, 10)    # 10 digit classes

    def forward(self, x):
        x = F.relu(self.pool(self.conv1(x)))
        x = F.relu(self.pool(self.conv2(x)))
        # flatten(1) keeps the batch dimension intact. view(-1, 1024) would
        # silently re-batch mis-sized inputs instead of failing in fc1.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def cnn():
    """Build and return a freshly initialised ``CNN`` instance."""
    model = CNN()
    return model


## <span style="color:#DFFF00">1. Data Loading and distribution</span>

In [6]:
## Data Loader-MNIST
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from matplotlib import pyplot as plt
import numpy as np
import torch.nn.functional as F

class loader(object):
    """MNIST source that partitions the training set across federated clients.

    ``type`` selects the partitioning scheme:

    * ``"IID"``: each client gets an equal slice of every class.
    * ``"NON_IID"``: the label-sorted training set is cut into ``num_shards``
      equal shards, and each client gets ``num_shards // num_clients`` shards
      chosen by a fixed-seed random permutation (the "pathological non-IID"
      split of McMahan et al., 2017). With MNIST's 60k samples and the
      defaults, that is 200 shards of 300 samples, 4 shards per client.
    """

    NUM_CLASSES = 10  # MNIST digits 0-9
    # Class-level defaults; __init__ sets per-instance values.
    num_clients = 50
    num_shards = 200
    partition_seed = 0

    def __init__(self, batch_size=64, type="NON_IID", num_clients=50,
                 num_shards=200, partition_seed=0):
        self.batch_size = batch_size
        self.type = type
        self.num_clients = num_clients
        self.num_shards = num_shards
        # Fixed seed so every get_loader() call sees the same disjoint partition.
        self.partition_seed = partition_seed
        self.__load_dataset()
        self.__get_index()

    def __load_dataset(self):
        """Download (if needed) and load the MNIST train/test splits."""
        # Standard MNIST mean/std normalisation.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.train_mnist = datasets.MNIST('./dataset/', train=True,
                                          download=True, transform=transform)
        self.test_mnist = datasets.MNIST('./dataset/', train=False,
                                         download=True, transform=transform)

    def __get_index(self):
        """Group training-sample indices by label into ``self.indices[label]``."""
        self.train_dataset = self.train_mnist
        self.test_dataset = self.test_mnist
        self.indices = [[] for _ in range(self.NUM_CLASSES)]
        for index, label in enumerate(self._labels(self.train_dataset)):
            self.indices[label].append(index)

    @staticmethod
    def _labels(dataset):
        """Return the integer label of every sample in ``dataset``, in order."""
        # Prefer the ``targets`` attribute: iterating the dataset would decode
        # and normalise every image just to read its label.
        targets = getattr(dataset, "targets", None)
        if targets is None:
            return [int(dataset[i][1]) for i in range(len(dataset))]
        if torch.is_tensor(targets):
            return targets.tolist()
        return [int(t) for t in targets]

    def _parse_rank(self, rank):
        """Return the client id encoded in ``rank``, or None for the server."""
        if rank is None:
            return None
        if isinstance(rank, (list, tuple)):
            if not rank:  # empty list -> server loader
                return None
            rank = rank[0]
        client_id = int(rank)
        if not 0 <= client_id < self.num_clients:
            raise ValueError(
                f"client rank {client_id} out of range [0, {self.num_clients})")
        return client_id

    def _iid_partition(self):
        """Give each client an equal, contiguous slice of every class."""
        parts = [[] for _ in range(self.num_clients)]
        for label_indices in self.indices:
            # Per-class slice size, so smaller classes are still split across
            # all clients instead of leaving the last clients without them.
            per_client = len(label_indices) // self.num_clients
            for client, part in enumerate(parts):
                start = client * per_client
                part.extend(label_indices[start:start + per_client])
        return parts

    def _non_iid_partition(self):
        """Deal label-sorted shards to clients using a fixed-seed permutation."""
        by_label = [idx for label_indices in self.indices for idx in label_indices]
        shard_size = len(by_label) // self.num_shards
        shards_per_client = self.num_shards // self.num_clients
        if shard_size == 0 or shards_per_client == 0:
            raise ValueError(
                "NON_IID partition needs at least one sample per shard and "
                "one shard per client")
        order = np.random.RandomState(self.partition_seed).permutation(self.num_shards)
        parts = []
        for client in range(self.num_clients):
            own = order[client * shards_per_client:(client + 1) * shards_per_client]
            parts.append([idx
                          for shard in sorted(own)
                          for idx in by_label[shard * shard_size:(shard + 1) * shard_size]])
        return parts

    def _partition(self):
        """Return one list of training indices per client for ``self.type``."""
        if self.type == "IID":
            return self._iid_partition()
        if self.type == "NON_IID":
            return self._non_iid_partition()
        raise ValueError(
            f"unknown partition type {self.type!r}; expected 'IID' or 'NON_IID'")

    def get_loader(self, rank):
        """Return ``(train_loader, test_loader)``.

        Args:
            rank: ``[]`` or ``None`` for the server, which gets every client's
                partition concatenated. Otherwise ``[k]`` (or the int ``k``)
                for client ``k``, which gets only its own partition.

        Raises:
            ValueError: unknown ``self.type``, or client id outside
                ``[0, num_clients)``.
        """
        client_id = self._parse_rank(rank)
        parts = self._partition()
        if client_id is None:
            train_set = torch.utils.data.ConcatDataset(
                [torch.utils.data.Subset(self.train_dataset, part) for part in parts])
        else:
            # Kept from the original: seed NumPy's global RNG per client.
            np.random.seed(client_id)
            train_set = torch.utils.data.Subset(self.train_dataset, parts[client_id])
        train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=True)
        return train_loader, test_loader

# Smoke check: build an IID loader (default batch size 64) and fetch the
# server-side (train_loader, test_loader) pair.
loader_test = loader(batch_size=64, type="IID")
loader_test.get_loader([])

(<torch.utils.data.dataloader.DataLoader at 0x2bedbb5edd0>,
 <torch.utils.data.dataloader.DataLoader at 0x2beda0f2b50>)

## <span style="color:#DFFF00">2. Edge server (ES) and Client classes</span>

In [None]:
## Edited day: 2023. 11. 16.

import torch
from torch import optim
from torch import nn
from torch.autograd import Variable
#from Model import model ## Define your model
#from Data_Loader import loader ## Define your data loader

class ES():
    """Edge server that aggregates its clients' weights with FedAvg.

    Clients write their trained ``state_dict`` into ``self.clients``, a
    circular buffer of length ``size`` indexed by ``self.count``.
    ``aggregate()`` averages those weights, weighted by each client's data
    size (``load``), loads the result into the global model, and records its
    test accuracy.
    """

    def __init__(self, size, data_loader, load, device, model=None):
        """
        Args:
            size: number of clients attached to this edge server.
            data_loader: ``(train_loader, test_loader)`` pair; only the test
                loader is used.
            load: per-client data sizes, in the order clients fill
                ``self.clients``.
            device: torch device (e.g. ``"cuda"`` or ``"cpu"``).
            model: global model; defaults to ``cnn()`` from this notebook.
        """
        self.size = size                       # number of clients
        self.model = (cnn() if model is None else model).to(device)
        self.test_loader = data_loader[1]      # test dataset loader
        self.accuracy = []                     # test accuracy (%) after each aggregate()
        self.clients = [None] * size           # latest weights per client slot
        self.count = 0                         # write cursor for the circular buffer
        self.load = load                       # data size of each client k
        self.load_s = sum(load)                # total data size over all clients
        self.device = device                   # CUDA or CPU

    def average_weights(self, clients, load=None):
        """Return the data-size-weighted average of client ``state_dict``s.

        Args:
            clients: list of ``state_dict``s, usually ``self.clients``. ``None``
                entries (slots not yet reported) are skipped.
            load: data size of each entry in ``clients``. Defaults to
                ``self.load``; pass it explicitly when averaging a subset.

        Returns:
            A new dict with the same keys; each tensor keeps its original dtype.

        Raises:
            ValueError: no client weights to average, or their total load is 0.
        """
        if load is None:
            load = self.load
        reported = [(w, n) for w, n in zip(clients, load) if w is not None]
        if not reported:
            raise ValueError("average_weights: no client weights to average")
        total = sum(n for _, n in reported)
        if total <= 0:
            raise ValueError("average_weights: total client load must be positive")
        averaged = {}
        for key, ref in reported[0][0].items():
            # Accumulate in float32 so integer or half-precision entries
            # average correctly, then cast back to the stored dtype.
            acc = sum(w[key].float() * (n / total) for w, n in reported)
            averaged[key] = acc.to(ref.dtype)
        return averaged

    def aggregate(self):
        """Load the averaged client weights into the global model and test it.

        Requires ``self.clients`` to hold at least one client's weights.
        Appends the resulting test accuracy to ``self.accuracy``.
        """
        self.model.load_state_dict(self.average_weights(self.clients))
        self.accuracy.append(self.test())

    def global_weight(self):
        """Return the current global model's ``state_dict``."""
        return self.model.state_dict()

    def test(self):
        """Evaluate the global model on ``self.test_loader``.

        Returns:
            Accuracy in percent, or 0.0 when the test loader is empty.
        """
        self.model.eval()
        correct = 0
        seen = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                pred = self.model(data).argmax(dim=1)
                correct += (pred == target).sum().item()
                seen += target.size(0)
        return 100.0 * correct / seen if seen else 0.0

class Client():
    """Federated client that trains a private copy of its edge server's global model."""

    def __init__(self, rank, data_loader, local_epoch, ES):
        # Seed torch's global RNG per client so runs are reproducible.
        seed = 19201077 + 19950920 + rank
        torch.manual_seed(seed)
        self.rank = rank                   # client ID
        self.local_epoch = local_epoch     # FedAvg local epochs per round
        self.ES = ES                       # edge server this client belongs to
        self.test_loader = data_loader[1]  # test data
        # Keep the DataLoader itself, not iter(...): a one-shot iterator is
        # exhausted after one pass, so later epochs and rounds would train on
        # nothing.
        self.train_loader = data_loader[0]

    def load_global_model(self):
        """Return a private copy of the edge server's current global model.

        Uses ``self.ES.model`` when the edge server exposes one; otherwise
        builds a fresh ``cnn()``. Either way it loads
        ``self.ES.global_weight()`` and moves the model to ``self.ES.device``.
        """
        import copy
        global_model = getattr(self.ES, "model", None)
        if global_model is None:
            global_model = cnn()
        model = copy.deepcopy(global_model)
        model.load_state_dict(self.ES.global_weight())
        return model.to(self.ES.device)

    def train(self, model, lr=0.01, momentum=0.0):
        """Run ``self.local_epoch`` passes of SGD over this client's data.

        Args:
            model: local model copy to train in place.
            lr: SGD learning rate.
            momentum: SGD momentum.

        Returns:
            A deep copy of the trained model's ``state_dict``.
        """
        import copy
        device = self.ES.device
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
        model.train()
        for _ in range(self.local_epoch):
            for data, target in self.train_loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                # NLL loss assumes the model returns log-probabilities, as the
                # notebook's CNN does (log_softmax output).
                loss = nn.functional.nll_loss(model(data), target)
                loss.backward()
                optimizer.step()
        return copy.deepcopy(model.state_dict())

    def run(self):
        """Pull the global model, train it locally, and store the weights on the ES."""
        model = self.load_global_model()
        weights = self.train(model)
        # ES.clients is a circular buffer; ES.count is its write cursor.
        self.ES.clients[self.ES.count % self.ES.size] = weights
        self.ES.count += 1

        

## <span style="color:#DFFF00">3. Run</span>

In [None]:
import copy

def fed_AVG(n_client, n_ES, ES_epoch, CL_epoch, batch_size, type = "NON_IID", device=None):
    """Run FedAvg on MNIST with ``n_ES`` edge servers of ``n_client`` clients each.

    Args:
        n_client: clients per edge server.
        n_ES: number of edge servers. Client ``j`` of server ``i`` requests
            data rank ``i * n_client + j`` from the loader.
        ES_epoch: communication rounds. Each round, every client trains and
            then each edge server aggregates.
        CL_epoch: local epochs per client per round.
        batch_size: mini-batch size for all loaders.
        type: ``"IID"`` or ``"NON_IID"`` partition, passed to ``loader``.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        List of ``n_ES`` deep-copied global ``state_dict``s, one per edge server.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print('Initialize Dataset...')
    # loader() takes (batch_size, type); a leading positional argument would
    # collide with batch_size= and raise a TypeError.
    data_loader = loader(batch_size=batch_size, type=type)
    server_loaders = data_loader.get_loader([])  # ES only uses the test loader

    print('Initialize Edge Servers and Clients...')
    ESs = []
    clients = []
    for i in range(n_ES):
        client_loaders = [data_loader.get_loader([i * n_client + j]) for j in range(n_client)]
        # FedAvg weights each client by its number of training samples.
        load = [len(train.dataset) for train, _ in client_loaders]
        es = ES(n_client, server_loaders, load, device)
        ESs.append(es)
        clients.append([Client(i * n_client + j, client_loaders[j], CL_epoch, es)
                        for j in range(n_client)])

    # Federated learning rounds.
    for _ in range(ES_epoch):
        for es, es_clients in zip(ESs, clients):
            for c in es_clients:
                c.run()
            es.aggregate()
        # No explicit broadcast step: Client.run() reloads the global model
        # before training.

    return [copy.deepcopy(es.global_weight()) for es in ESs]
if __name__ == '__main__':
    # fed_AVG requires all of its hyperparameters. Example configuration:
    # 5 edge servers x 10 clients = 50 clients, matching the loader's
    # 50-client partition.
    w = fed_AVG(n_client=10, n_ES=5, ES_epoch=10, CL_epoch=1, batch_size=64, type="NON_IID")