In [1]:
## Model.py
## Example of Model
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """Two-conv / two-dense classifier that emits per-class log-probabilities.

    The first convolution takes a single input channel and the flatten step
    reshapes to 1024 features (64 channels x 4 x 4), which corresponds to
    1x28x28 inputs after two 5x5 convolutions each followed by 2x2 max-pooling.
    The final layer produces 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        # Registered but not applied in forward(); kept so the module's
        # attributes stay the same.
        self.conv2_drop = nn.Dropout2d()
        self.pool = nn.MaxPool2d(kernel_size=2)
        # Dense classification head.
        self.fc1 = nn.Linear(1024, 2048)
        self.fc2 = nn.Linear(2048, 10)

    def forward(self, x):
        """Return log-softmax scores (over dim 1) for the batch ``x``."""
        # Each stage is conv -> max-pool -> ReLU.
        for conv in (self.conv1, self.conv2):
            x = F.relu(self.pool(conv(x)))
        flat = x.view(-1, 1024)
        # Dropout on the hidden activations is only active in training mode.
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

def cnn():
    """Factory helper: build and return a new ``CNN`` instance."""
    model = CNN()
    return model


In [5]:
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import numpy as np
import random

class loader(object):
    """Build MNIST train/test ``DataLoader`` pairs for federated clients.

    On construction the MNIST train and test splits are downloaded to
    ``./dataset/`` (if missing) and the training indices are grouped by label.
    ``get_loader`` then assembles a training subset according to ``type``:

    * ``"NON_IID"`` - 5 of the 10 classes are drawn at random and every shard
      of those classes is used.
    * ``"IID"`` - every shard of every class is used.

    Args:
        batch_size: Batch size used for both returned loaders.
        type: Partitioning scheme, ``"NON_IID"`` (default) or ``"IID"``.
    """

    def __init__(self, batch_size=64, type="NON_IID"):
        self.batch_size = batch_size
        self.type = type
        self.__load_dataset()
        self.__get_index()

    def __load_dataset(self):
        """Download (if needed) and open the MNIST train and test splits."""
        # Constants are presumably the MNIST mean/std - TODO confirm.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.train_mnist = datasets.MNIST('./dataset/',
                                          train=True,
                                          download=True,
                                          transform=transform)
        self.test_mnist = datasets.MNIST('./dataset/',
                                         train=False,
                                         download=True,
                                         transform=transform)

    def __get_index(self):
        """Group training-sample indices by label into ``self.indices``."""
        self.train_dataset = self.train_mnist
        self.test_dataset = self.test_mnist

        # Read the labels straight from ``targets`` instead of iterating the
        # dataset, which would decode and transform every image just to
        # obtain its label.
        labels = torch.as_tensor(self.train_dataset.targets).tolist()
        self.indices = [[] for _ in range(10)]
        for index, label in enumerate(labels):
            self.indices[label].append(index)

    def __build_shards(self, shards_per_class, num_classes=10):
        """Cut each class's sorted indices into ``shards_per_class`` slices.

        Shards are ordered so that shard ``k`` belongs to class
        ``k % num_classes``; ``shards[c::num_classes]`` therefore yields all
        shards of class ``c``.
        """
        # NOTE(review): the shard length is derived from class 0 only, as in
        # the original code; classes with a different number of samples get a
        # truncated or shorter last shard. Confirm this is intended.
        shard_size = len(self.indices[0]) // shards_per_class
        sorted_indices = [sorted(self.indices[c]) for c in range(num_classes)]

        shards = []
        for shard_idx in range(num_classes * shards_per_class):
            class_idx = shard_idx % num_classes
            start_idx = (shard_idx // num_classes) * shard_size
            shards.append(sorted_indices[class_idx][start_idx:start_idx + shard_size])
        return shards

    def get_loader(self, rank):
        """Return ``(train_loader, test_loader)`` for one participant.

        Args:
            rank: A sequence whose first element is used as an integer seed
                for the class selection, or an empty sequence / ``None`` to
                draw the selection from the global ``random`` state.

        Returns:
            Tuple of two shuffling ``DataLoader`` objects: one over the
            selected training shards and one over the full test split.

        Raises:
            ValueError: If ``self.type`` is neither ``"NON_IID"`` nor ``"IID"``.
        """
        if rank:
            seed = int(rank[0])
            # Kept from the original code: seeds NumPy's global generator.
            np.random.seed(seed)
            # The class selection uses the stdlib ``random`` API, which
            # ``np.random.seed`` does not affect; a dedicated generator makes
            # the selection reproducible for a given seed.
            rng = random.Random(seed)
        else:
            rng = random

        num_classes = 10
        if self.type == "NON_IID":
            shards = self.__build_shards(2, num_classes)
            selected_classes = rng.sample(range(num_classes), 5)
            selected_shards = [shard
                               for class_idx in selected_classes
                               for shard in shards[class_idx::num_classes]]
        elif self.type == "IID":
            selected_shards = self.__build_shards(5, num_classes)
        else:
            raise ValueError(
                "Unknown partition type {!r}; expected 'NON_IID' or 'IID'".format(self.type))

        subsets = [torch.utils.data.Subset(self.train_dataset, shard)
                   for shard in selected_shards]
        train_loader = DataLoader(torch.utils.data.ConcatDataset(subsets),
                                  batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(self.test_dataset,
                                 batch_size=self.batch_size, shuffle=True)

        return train_loader, test_loader
# Smoke test: build the default (NON_IID) loader and request the
# (train_loader, test_loader) pair for seed 1. Constructing `loader`
# downloads MNIST to ./dataset/ if it is not already there.
loader1 = loader()
loader1.get_loader([1])




(<torch.utils.data.dataloader.DataLoader at 0x22d44e368b0>,
 <torch.utils.data.dataloader.DataLoader at 0x22d44e36bb0>)

In [3]:
## Edge_Server.py
import torch
from torch import optim
from torch import nn
from torch.autograd import Variable

class ES():
    """Edge server: collects client weights, averages them and holds the model.

    Args:
        size: Number of client slots; ``self.clients`` is a list of this
            length that clients fill with their state dicts.
        data_loader: ``(train_loader, test_loader)`` pair; only the test
            loader (index 1) is kept.
        device: Device the server model is moved to and evaluated on.
    """

    def __init__(self, size, data_loader, device):
        self.size = size
        self.test_loader = data_loader[1]
        self.accuracy = []
        self.clients = [None]*size
        self.count = 0
        self.model =  CNN().to(device)
        self.device = device
        # Optional upstream (cloud) server. ``aggregate`` forwards the
        # averaged weights to it when one has been attached.
        self.CS = None

    def average_weights(self, clients):
        """Return the element-wise mean of the given state dicts.

        Args:
            clients: Sequence of state dicts sharing the same keys. ``None``
                entries (slots no client has filled yet) are skipped.

        Returns:
            A new dict mapping each key to the mean tensor. The input dicts
            are left unmodified.

        Raises:
            ValueError: If ``clients`` contains no state dict.
        """
        states = [state for state in clients if state is not None]
        if not states:
            raise ValueError("average_weights() needs at least one client state dict")

        first, rest = states[0], states[1:]
        averaged = {}
        for key in first:
            total = first[key].clone()
            for state in rest:
                total = state[key] + total
            # Divide by the number of states actually summed.
            averaged[key] = total / len(states)
        return averaged

    def aggregate(self):
        """Average the collected client weights and load them into the model.

        If an upstream server has been attached as ``self.CS``, the averaged
        weights are also written into its next ``ESs`` slot and its counter
        is advanced. Accuracy is not evaluated here; call ``test()`` for that.
        """
        weights = self.average_weights(self.clients)
        self.model.load_state_dict(weights)

        upstream = getattr(self, "CS", None)
        if upstream is not None:
            upstream.ESs[upstream.count % upstream.size] = weights
            upstream.count += 1

    def global_weight(self):
        """Return the state dict of the server model."""
        weights = self.model.state_dict()
        return weights

    def test(self):
        """Return the fraction of correctly classified test samples."""
        test_correct = 0
        self.model.eval()
        device = self.device
        with torch.no_grad():
            for data, target in self.test_loader:
                # Tensors are moved directly; the ``Variable`` wrapper is
                # deprecated and no longer needed.
                data, target = data.to(device), target.to(device)
                output = self.model(data)
                pred = output.argmax(dim=1, keepdim=True)
                test_correct += pred.eq(target.view_as(pred)).sum().item()
        return test_correct / len(self.test_loader.dataset)
    


In [4]:
## Hier_Client.py
import torch

from torch import nn
from torch.autograd import Variable
class Client(object):
    """Federated client that trains a copy of its edge server's model.

    Args:
        rank: Integer client id; also offsets the torch seed set on creation.
        data_loader: ``(train_loader, test_loader)`` pair for this client.
        local_epoch: Number of training steps (one mini-batch each) per
            ``run()`` call.
        ES: The edge-server instance this client reports to. It must expose
            ``model``, ``device``, ``clients``, ``count`` and ``size``.
    """

    def __init__(self, rank, data_loader, local_epoch, ES):
        seed = 19201077 + 19950920 + rank
        torch.manual_seed(seed)
        self.accuracy = []
        self.rank = rank
        self.local_epoch = local_epoch
        self.ES=ES
        self.test_loader = data_loader[1]
        # Keep the underlying loader so the batch iterator can be restarted
        # once it has been consumed.
        self._train_source = data_loader[0]
        self.train_loader = iter(self._train_source)

    def load_global_model(self):
        """Return a fresh ``CNN`` initialised with the edge server's weights."""
        # Use this client's server instance, not the ``ES`` class itself.
        model = CNN().to(self.ES.device)
        model.load_state_dict(self.ES.model.state_dict())
        return model

    def __next_batch(self):
        """Return the next training batch, restarting the iterator if needed.

        Returns ``None`` only when no batch can be produced at all.
        """
        try:
            return next(self.train_loader)
        except StopIteration:
            source = getattr(self, "_train_source", None)
            if source is None:
                return None
            self.train_loader = iter(source)
            return next(self.train_loader, None)

    def __train(self, model):
        """Run one SGD step on a single mini-batch and return the state dict.

        The loss of the step is stored in ``self.loss``.
        """
        device = self.ES.device
        model.train()
        # NOTE(review): the original also built a LambdaLR scheduler, but its
        # step() sat after an unconditional break and never ran, so the
        # learning rate was effectively constant. It is left out here.
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

        batch = self.__next_batch()
        if batch is not None:
            data, target = batch
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            self.loss = nn.CrossEntropyLoss()(output, target)
            self.loss.backward()
            optimizer.step()

        return model.state_dict()

    def run(self):
        """Train locally for ``local_epoch`` steps and hand the weights to the ES.

        The resulting state dict is written into the server's next client
        slot (``count % size``) and the server's counter is advanced.
        """
        model = self.load_global_model()
        # Defined up front so ``local_epoch == 0`` still reports weights.
        weights = model.state_dict()
        for _ in range(self.local_epoch):
            weights = self.__train(model=model)

        self.ES.clients[self.ES.count%self.ES.size]=weights
        self.ES.count+=1
        


In [5]:
import random
def fed_AVG(n_client, n_ES, ES_epoch, epoch, batch_size, device = "cuda:0" ,type = "IID"):
    """Run hierarchical federated averaging over edge servers and clients.

    Args:
        n_client: Number of clients attached to each edge server.
        n_ES: Number of edge servers.
        ES_epoch: Number of edge-server rounds.
        epoch: Local training steps each client runs per round.
        batch_size: Batch size passed to the dataset loader.
        device: Torch device or device string. String values are
            lower-cased, so the former default ``"Cuda:0"`` is still accepted.
            Falls back to CPU when CUDA is requested but not available.
        type: Data partitioning scheme passed to ``loader``.

    Returns:
        The list of edge servers after the final round.
    """
    # torch device strings are case-sensitive ("Cuda:0" is rejected).
    if isinstance(device, str):
        device = device.lower()
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        print('CUDA is not available; falling back to CPU.')
        device = torch.device("cpu")

    print('Initialize Dataset...')
    # ``loader`` takes no dataset-name argument; passing one positionally
    # collides with the ``batch_size`` keyword.
    data_loader = loader(batch_size=batch_size, type = type)
    print('Initialize Edge Servers and Clients...')
    ESs = []
    clients = []
    for i in range(n_ES):
        ESs.append(ES(size = n_client, data_loader = data_loader.get_loader([]), device = device))
        clients.append([
            Client(rank=j,
                   data_loader=data_loader.get_loader(random.sample(range(0, 10), 4)),
                   local_epoch = epoch,
                   ES = ESs[i])
            for j in range(n_client)
        ])

    # federated learning
    for ESe in range(ES_epoch):
        print('\n================== Edge Server Epoch {:>3} =================='.format(ESe + 1))
        for ESn in range(n_ES):
            print("================= Edge Server :",ESn,"process =================")
            for c in clients[ESn]:
                c.run()

            ESs[ESn].aggregate()

    return ESs
      