# Federated Learning Simulator

In [1]:
import copy
import os
import numpy as np
import random  

from torchvision import datasets, transforms

from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler


In [None]:
# Load data (CIFAR-10 and MNIST)
#
# CIFAR-10: 32*32*3 images, train 50k / test 10k
# MNIST:    28*28 images,   train 60k / test 10k
#
# Both datasets are requested under ./data with download=True; dataset()
# later returns the pair that matches DATASET_SELECT.

cifar10_train, cifar10_test = (
    datasets.CIFAR10(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)

mnist_train, mnist_test = (
    datasets.MNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)



In [3]:
# Configurations

# ---------------- Simulation parameters -----------------------
# NUM_CLIENTS:    number of clients used in the simulation
# NUM_SERVER:     number of servers (only reported by print_settings)
# NUM_ITER:       number of federated iterations
# LEN_PER_ITER:   number of training samples each client uses per iteration
# TESTSET_SIZE:   intended number of test samples used for evaluation
# DATASET_SELECT: which dataset to use: 'cifar-10' or 'mnist'
#
# NOTE: NUM_CLIENTS * NUM_ITER * LEN_PER_ITER must not exceed the size of the
# selected training set, otherwise some clients receive empty slices.
NUM_CLIENTS = 100
NUM_SERVER = 1
NUM_ITER = 4
LEN_PER_ITER = 100

TESTSET_SIZE = 300
DATASET_SELECT = 'cifar-10'

# ---------------- Print options -----------------------
# PRINT_PACKET_MESSAGE:
#     True:  str(packet) includes the message (weights, intercepts, ...)
#     False: str(packet) shows sender and recipient only
# PRINT_SIMULATION_LOG:
#     True:  print simulation logs
#     False: print final results only
# PRINT_DETAILS (only consulted when PRINT_SIMULATION_LOG is True):
#     True:  also print every packet transfer and the active-client list
#     False: skip those details
# PRINT_INDIVIDUAL_RESULT:
#     True:  print the result of each client
#     False: print the averaged result only
PRINT_PACKET_MESSAGE = False
PRINT_SIMULATION_LOG = True
PRINT_DETAILS = False
PRINT_INDIVIDUAL_RESULT = True

# ---------------- Clients setting -----------------------
# CLIENT_NAMES: list of client names: 'client0', 'client1', ...
CLIENT_NAMES = ['client{}'.format(index) for index in range(NUM_CLIENTS)]

# ---------------- Dropout setting -----------------------
# USE_DROPOUT:   True to let clients randomly disconnect during the simulation
# DROPOUT_RATIO: probability that a client disconnects after an iteration
USE_DROPOUT = False
DROPOUT_RATIO = 0.1

# ---------------- Dataset -----------------------
# CIFAR10_TRAIN / CIFAR10_TEST: CIFAR-10 train / test set sizes
# MNIST_TRAIN / MNIST_TEST:     MNIST train / test set sizes
CIFAR10_TRAIN = 50000
CIFAR10_TEST = 10000
MNIST_TRAIN = 60000
MNIST_TEST = 10000

# ---------------- Security option -----------------------
# SECURITY_TYPE: secure aggregation algorithm.
#     Options: 'no_security', 'additive_mask', ...
# WEIGHT_SIZE / INTERCEPT_SIZE: flattened parameter sizes; assigned by
#     dataset() once the dataset is chosen.
SECURITY_TYPE = 'no_security'
WEIGHT_SIZE = None
INTERCEPT_SIZE = None


In [15]:
# Modules & Utils
'''
Packet: Message container exchanged between the clients and the server.
Network: Connects the clients and the server used in the simulation.
Evaluator: Evaluates model parameters on each iteration.

partition: Builds the local dataset of each client from the training set.
dataset: Selects the dataset.
transforms: Reshapes the data arrays.
'''

class Packet:
    """Envelope for one message travelling between a client and the server.

    Attributes:
        sender: name of the sending party.
        recipient: name of the receiving party.
        message: the payload carried by the packet.
    """

    def __init__(self, sender, recipient, message):
        self.sender, self.recipient, self.message = sender, recipient, message

    def __str__(self):
        """Return a one-line description; the payload is hidden unless
        PRINT_PACKET_MESSAGE (see configurations) is truthy."""
        if not PRINT_PACKET_MESSAGE:
            return "Sender: {}, Recipient: {}, Message: SKIP\n".format(self.sender, self.recipient)
        return "Sender: {}, Recipient: {}, Message: {}\n".format(self.sender, self.recipient, self.message)


class Network:
    """Registry that ties the clients and the server together.

    Attributes:
        clients: every client taking part in the simulation.
        server: the server object.
        active_clients: name -> client for clients currently connected
            (dropped-out clients are removed).
        buffer: placeholder packet buffer for a multi-server setup (unused here).
    """

    def __init__(self, clients, server):
        self.clients = clients
        self.server = server
        self.active_clients = {}
        self.buffer = {}

    def connect(self, client):
        """Register ``client`` as active under its name."""
        self.active_clients.update({client.name: client})

    def disconnect(self, client):
        """Remove ``client`` from the active set; KeyError if it is not there."""
        self.active_clients.pop(client.name)

    def num_actives(self):
        """Return the number of active clients, logging it when
        PRINT_SIMULATION_LOG is on (plus the names when PRINT_DETAILS is on)."""
        count = len(self.active_clients)
        if not PRINT_SIMULATION_LOG:
            return count
        print('Remaining clients: {} \n'.format(count))
        if PRINT_DETAILS:
            print(list(self.active_clients))
        return count

 
        
class Evaluator:
    """Scores externally supplied parameters on a fixed slice of the test set.

    A LogisticRegression estimator is fitted once on the evaluation slice;
    evaluate() then overwrites its ``coef_`` / ``intercept_`` with the
    parameters to be scored. The initial fit is presumably there only to put
    the estimator into a fitted state so that ``score`` can be called.
    """

    def __init__(self, X_test, y_test, testset_size=None):
        """Keep the first ``testset_size`` test samples for evaluation.

        Args:
            X_test: test features, one row per sample.
            y_test: test labels aligned with ``X_test``.
            testset_size: number of leading samples to keep; defaults to the
                configured TESTSET_SIZE instead of a hard-coded value.
        """
        if testset_size is None:
            testset_size = TESTSET_SIZE
        self.X = X_test[:testset_size]
        self.y = y_test[:testset_size]
        self.model = LogisticRegression()
        self.model.fit(self.X, self.y)

    def evaluate(self, W, b):
        """Return the accuracy of weights ``W`` and intercepts ``b`` on the
        stored evaluation slice.

        The shared estimator is mutated: its coefficients are replaced on
        every call.
        """
        self.model.coef_ = W
        self.model.intercept_ = b
        return self.model.score(self.X, self.y)


def partition(X, y, client_names, num_iter, len_per_iter):
    """Split the training data into disjoint per-client, per-iteration slices.

    Clients are served in the order given; each one owns a contiguous run of
    ``num_iter * len_per_iter`` samples, cut into ``num_iter`` consecutive
    chunks of ``len_per_iter`` samples.

    Args:
        X: sliceable training features.
        y: sliceable training labels aligned with ``X``.
        client_names: iterable of client names.
        num_iter: number of iterations (chunks) per client.
        len_per_iter: number of samples per chunk.

    Returns:
        ``{client_name: {iteration: (X_chunk, y_chunk)}}`` with iterations
        numbered from 1. Slices past the end of the data are simply shorter
        or empty.
    """
    samples_per_client = num_iter * len_per_iter
    local_datasets = {}
    for position, client_name in enumerate(client_names):
        first = position * samples_per_client  # start of this client's run
        chunks = {}
        for round_no in range(1, num_iter + 1):
            lo = first + (round_no - 1) * len_per_iter
            hi = lo + len_per_iter
            chunks[round_no] = (X[lo:hi], y[lo:hi])
        local_datasets[client_name] = chunks
    return local_datasets



def dataset():
    """Return ``(trainset, testset)`` for the dataset named by DATASET_SELECT.

    Side effect: sets the globals WEIGHT_SIZE and INTERCEPT_SIZE to the
    flattened parameter sizes of a 10-class linear model on that dataset.

    Raises:
        ValueError: if DATASET_SELECT is neither 'cifar-10' nor 'mnist'.
    """
    global WEIGHT_SIZE
    global INTERCEPT_SIZE
    if DATASET_SELECT not in ('cifar-10', 'mnist'):
        raise (ValueError('Unexpected dataset selection.\n'))

    if DATASET_SELECT == 'mnist':
        WEIGHT_SIZE = (10*28*28)
        INTERCEPT_SIZE = 10
        return mnist_train, mnist_test

    # Remaining case: 'cifar-10'
    WEIGHT_SIZE = 10*32*32*3
    INTERCEPT_SIZE = 10
    return cifar10_train, cifar10_test

def transforms(traindata, testdata):
    """Flatten every sample of the train and test arrays into one row.

    The row count is taken from each array itself rather than from the
    CIFAR10_*/MNIST_* size constants, so a subset of a dataset is flattened
    correctly instead of being silently re-cut into the wrong number of rows.

    NOTE: this function shadows the ``transforms`` module imported from
    torchvision at the top of the file; the name is kept because callers in
    this file use it.

    Args:
        traindata: array-like with ``reshape`` and ``len`` whose first axis
            indexes samples.
        testdata: same, for the test split.

    Returns:
        ``(traindata_2d, testdata_2d)``, each shaped ``(n_samples, -1)``.

    Raises:
        ValueError: if DATASET_SELECT is neither 'cifar-10' nor 'mnist'.
    """
    if DATASET_SELECT not in ('cifar-10', 'mnist'):
        raise ValueError('Unexpected dataset selection.\n')
    return (traindata.reshape(len(traindata), -1),
            testdata.reshape(len(testdata), -1))

In [5]:
# Secure Aggregation
'''
TODO: Secure Aggregation Algorithms

'''

class Security:
    """Creates per-client additive masks ("offsets") for secure aggregation.

    Every scheme fills ``security_offset`` with ``name -> (W_mask, b_mask)``
    such that the masks of all clients add up to zero; the server-side
    average of masked parameters then equals the average of the unmasked
    ones.

    Attributes:
        num_clients: number of clients (from NUM_CLIENTS at construction).
        W_size: flattened weight size (WEIGHT_SIZE at construction).
        b_size: intercept size (INTERCEPT_SIZE at construction); also the
            number of rows the weight mask is reshaped to.
        security_offset: name -> (W_mask, b_mask).
        pair: name -> names of the clients sharing that client's mask group.
    """

    def __init__(self):
        self.security = None
        self.num_clients = NUM_CLIENTS
        self.W_size = WEIGHT_SIZE
        self.b_size = INTERCEPT_SIZE
        self.security_offset = {}
        self.pair = {client_name: [] for client_name in CLIENT_NAMES}
        #else(random keys, etc)

    def set_security(self):
        """Build the offsets for the scheme named by SECURITY_TYPE.

        Raises:
            ValueError: for an unknown SECURITY_TYPE.
        """
        if SECURITY_TYPE == 'no_security':
            self.no_security()
        elif SECURITY_TYPE == 'additive_mask':
            self.additive_masking()
        else:
            raise (ValueError('Unexpected Security Type.\n'))

    def _random_mask(self):
        """Draw one random (W, b) mask with the model's parameter shapes."""
        W = np.random.rand(self.W_size).reshape(self.b_size, -1)
        b = np.random.rand(self.b_size)
        return W, b

    def no_security(self):
        """Give every client an all-zero mask (parameters are sent as-is)."""
        W0 = np.zeros(self.W_size).reshape(self.b_size, -1)
        b0 = np.zeros(self.b_size)
        # The same zero arrays are shared by all clients; they are only read.
        self.security_offset = {client_name: (W0, b0) for client_name in CLIENT_NAMES}

    def additive_masking(self):
        """Random masks for the first n-1 clients; the last client gets the
        negated sum so that all masks cancel."""
        W_offset = []
        b_offset = []
        num_clients = self.num_clients
        for _ in range(num_clients - 1):
            Wi, bi = self._random_mask()
            W_offset.append(Wi)
            b_offset.append(bi)
        # Computed before appending so the sums cover the random masks only.
        Wn = - np.sum(W_offset, axis=0)
        bn = - np.sum(b_offset, axis=0)
        W_offset.append(Wn)
        b_offset.append(bn)

        for i in range(num_clients):
            name = 'client' + str(i)
            self.security_offset[name] = 100*W_offset[i], 100*b_offset[i]

    def pair_additive_masking(self):
        """Masks that cancel within small groups of clients.

        Clients are grouped in consecutive pairs (+mask / -mask). With an odd
        client count the last three clients form a trio whose third member
        carries the negated sum of the other two masks.

        Raises:
            ValueError: if the client count is odd and below 3, since no
                trio can be formed.
        """
        num_clients = self.num_clients
        is_odd = num_clients % 2 == 1
        if is_odd and num_clients < 3:
            raise ValueError('Pairwise masking needs at least 2 clients.\n')

        # With an odd count the last three clients are handled as a trio.
        paired = num_clients - 3 if is_odd else num_clients

        for i in range(0, paired, 2):
            Wi, bi = self._random_mask()
            name1 = 'client' + str(i)
            name2 = 'client' + str(i+1)
            self.security_offset[name1] = 100 * Wi, 100 * bi
            self.security_offset[name2] = -100 * Wi, -100 * bi
            self.pair[name1].append(name2)
            self.pair[name2].append(name1)

        if is_odd:
            # Two independent masks; the third member cancels both.
            Wi, bi = self._random_mask()
            Wj, bj = self._random_mask()

            name1 = 'client' + str(num_clients - 3)
            name2 = 'client' + str(num_clients - 2)
            name3 = 'client' + str(num_clients - 1)

            self.security_offset[name1] = 100 * Wi, 100 * bi
            self.security_offset[name2] = 100 * Wj, 100 * bj
            self.security_offset[name3] = -100 * (Wi + Wj), -100 * (bi + bj)

            # NOTE(review): trio partners are stored as one nested list per
            # client, unlike the flat names used for pairs; nothing in this
            # file reads self.pair - confirm the intended shape.
            self.pair[name1].append([name2,name3])
            self.pair[name2].append([name1,name3])
            self.pair[name3].append([name1,name2])

    def secure_aggregation(self, name, W, b):
        """Return ``(W, b)`` with the mask of client ``name`` added.

        Raises:
            KeyError: if no offset was generated for ``name``.
        """
        W_offset, b_offset = self.security_offset[name]
        return W + W_offset, b + b_offset

 


In [18]:
# CLIENT
class Client:
    """A simulated federated-learning participant.

    Per iteration a client trains a local SGD classifier on its own data
    chunk, masks the resulting parameters through the shared ``security``
    object, hands them to the server as a Packet and later stores the
    averaged parameters it receives back.

    Args:
        name: client name; also its key in the network and security tables.
        local_data: ``{iteration: (X, y)}`` training chunks, iterations
            numbered from 1.
        evaluator: object with ``evaluate(W, b)`` returning an accuracy.
        security: object with ``secure_aggregation(name, W, b)``.
    """

    def __init__(self, name, local_data, evaluator, security):
        # Device info
        self.name = name
        self.network = None  # set by connect()
        self.security = security
        self.active = None  # None until connect(); False after disconnect()
        self.dropped_iter = NUM_ITER  # overwritten by dropout()

        # Training data & evaluator
        self.local_data = local_data
        self.evaluator = evaluator

        # Weights, intercepts and accuracy, all keyed by iteration
        self.W_local = {}
        self.b_local = {}
        self.acc_local = {}

        self.W_federated = {}
        self.b_federated = {}
        self.acc_federated = {}

        # To check the secure aggregation effect (not required for the system)
        self.W_secured = {}
        self.b_secured = {}
        self.acc_secured = {}

    def connect(self, network):
        """Join ``network`` and mark this client as active."""
        self.network = network
        self.active = True
        network.connect(self)

    def disconnect(self):
        """Leave the network and mark this client as inactive."""
        self.active = False
        self.network.disconnect(self)

        if PRINT_SIMULATION_LOG:
            print('{} disconnected. \n'.format(self.name)) #log

    def train(self, iter):
        """Fit a local model on the chunk of iteration ``iter``.

        From the second iteration on, training is warm-started from the
        federated weights AND intercepts of the previous iteration.

        Returns:
            ``(W_local, b_local)``: the fitted coefficients and intercepts.
        """
        X, y = self.local_data[iter]
        # NOTE(review): newer scikit-learn releases spell this loss
        # "log_loss" and may reject "log" - confirm against the installed
        # version before changing it.
        model = SGDClassifier(alpha = 0.0001, loss = "log")
        W_federated = None
        b_federated = None
        if iter > 1: # if not first iteration, uses federated params from server
            W_federated = copy.deepcopy(self.W_federated[iter-1])
            b_federated = copy.deepcopy(self.b_federated[iter-1])

        model.fit(X, y, coef_init=W_federated, intercept_init=b_federated)

        # Local params of this iteration
        return model.coef_, model.intercept_

    def send(self, iter):
        """Train for iteration ``iter`` and build the packet for the server.

        Stores the local and the masked parameters together with their
        accuracies, then returns a Packet carrying the masked parameters.

        Raises:
            ValueError: if this client has no data for ``iter``.
        """
        if iter not in self.local_data:
            raise (ValueError('Not enough data in iteration # {}.'.format(iter)))

        # Client tasks start

        W_local, b_local = self.train(iter)

        self.W_local[iter] = W_local
        self.b_local[iter] = b_local

        # Copies keep the stored local params untouched by the security step
        W, b = copy.deepcopy(W_local), copy.deepcopy(b_local)

        # Secure Aggregation
        W, b = self.security.secure_aggregation(self.name, W, b)

        # Client tasks end

        self.W_secured[iter] = W
        self.b_secured[iter] = b

        # Local Evaluation
        self.acc_local[iter] = self.evaluator.evaluate(W_local, b_local)
        self.acc_secured[iter] = self.evaluator.evaluate(W, b)

        # Send message to server
        message = {'iter': iter, 'weights' : W, 'intercepts': b}
        packet = Packet(sender=self.name, recipient=self.network.server.name, message=message)

        # LOG
        if PRINT_SIMULATION_LOG and PRINT_DETAILS:
            print('Packet sent from {} to server'.format(self.name))
            print(packet)

        return packet

    def receive(self, packet):
        """Store and evaluate the federated parameters carried by ``packet``.

        When USE_DROPOUT is on, the client may disconnect afterwards.
        """
        # Received message
        message = packet.message
        iter, W, b = message['iter'], message['weights'], message['intercepts']
        # The server sends one shared message object to every client
        W = copy.deepcopy(W)
        b = copy.deepcopy(b)

        # Save federated params in local memory
        self.W_federated[iter] = W
        self.b_federated[iter] = b

        # Federated Evaluation
        self.acc_federated[iter] = self.evaluator.evaluate(W, b)

        # Dropout
        if USE_DROPOUT:
            self.dropout(iter)

        # LOG
        if PRINT_SIMULATION_LOG and PRINT_DETAILS:
            print('Packet received from server in {}'.format(self.name))
            print(packet)

    def dropout(self, iter):
        """Disconnect with probability DROPOUT_RATIO, recording ``iter``."""
        p = random.uniform(0, 1)
        if (p < DROPOUT_RATIO):
            self.disconnect()
            self.dropped_iter = iter
  




In [7]:
# SERVER
class Server:
    """Aggregation server: averages the clients' parameters every iteration.

    Attributes:
        W_server: iteration -> averaged weights.
        b_server: iteration -> averaged intercepts.
        network: the Network this server is attached to (set by connect()).
        name: server name used as packet sender / recipient.
    """

    def __init__(self, name):
        self.W_server = {}
        self.b_server = {}
        self.network = None
        self.name = name

    def connect(self, network):
        """Attach the server to ``network``."""
        self.network = network
        if PRINT_SIMULATION_LOG:
            print('Server connected\n')

    def update(self, iter):
        """Run ``iter`` federated rounds.

        In each round every active client trains and sends its (masked)
        parameters, the server averages them, and the average is sent back
        to the same clients.
        """
        for i in range(1, iter+1):
            # Snapshot in registration order: clients may drop out of
            # active_clients while receiving, and a fixed order keeps the
            # summation and the dropout draws reproducible between runs.
            clients = list(self.network.active_clients.values())
            print('\nIteration {}\n'.format(i)) # print current iteration
            # Print Active Clients
            if PRINT_SIMULATION_LOG:
                self.network.num_actives()

            # Receiving packets from each client
            packets = [client.send(i) for client in clients]

            # Server tasks started
            W_received = [np.array(packet.message['weights']) for packet in packets]
            b_received = [np.array(packet.message['intercepts']) for packet in packets]

            self.W_server[i] = np.average(W_received, axis=0)
            self.b_server[i] = np.average(b_received, axis=0)
            # Server tasks end

            # One message object is shared by all outgoing packets
            message = {'iter':i, 'weights':self.W_server[i], 'intercepts':self.b_server[i]}

            # Sending packets back to clients
            for client in clients:
                packet = Packet(sender=self.name, recipient=client.name, message=message)
                client.receive(packet)

    @staticmethod
    def _mean_per_iter(rows):
        """Column-wise mean of per-client accuracy rows; [] when no rows."""
        if not rows:
            return []
        return np.mean(rows, axis=0).tolist()

    def final_results(self): # Average of active clients
        """Print per-iteration accuracies averaged over the active clients."""
        acc_clients = []
        acc_server = []
        acc_secured = []

        for client_name in list(self.network.active_clients):
            client = self.network.clients[client_name]
            acc_clients.append(list(client.acc_local.values()))
            acc_server.append(list(client.acc_federated.values()))
            acc_secured.append(list(client.acc_secured.values()))

        acc_clients = self._mean_per_iter(acc_clients)
        acc_server = self._mean_per_iter(acc_server)
        acc_secured = self._mean_per_iter(acc_secured)

        print('\n\n--------- Simulation Results ----------\n')
        print('Local accuracy on each iter: {}\n'.format(acc_clients))
        print('Federated accuracy on each iter : {}\n'.format(acc_server))
        print('Accuracy test on secured W, b : {}\n'.format(acc_secured))

    def individual_results(self):
        """Print the accuracy history of every client, dropped ones included."""
        print('\n\n--------- Individual Results ----------\n')
        for client_name, client in self.network.clients.items():
            # A disconnected client has active == False; comparing
            # dropped_iter with NUM_ITER would miss a drop in the last round.
            if client.active is False:
                print('{} (dropped at iter {}) \n'.format(client_name, client.dropped_iter))
            else:
                print('{} \n'.format(client_name))
            print('Local Accuracy: {}\n'.format(list(client.acc_local.values())))
            print('Federated Accuracy: {}\n'.format(list(client.acc_federated.values())))
            print('Secured model Accuracy: {}\n'.format(list(client.acc_secured.values())))
            print('\n')

  

In [16]:
# Simulator
'''
Additional task list
- Print the simulation results (Server.final_results() and Server.individual_results())
- Apply security methods in the clients
- Check the convergence of each client
- Handle the client dropout situation
- Send packets through the network's buffer
- etc.
'''

class Simulator:
    """Builds and runs one federated-learning simulation.

    Construction loads and scales the selected dataset, partitions the
    training data among the clients, creates the shared Evaluator and
    Security objects, and connects all clients and the server to a Network.

    Args:
        num_clients: accepted for interface compatibility; the set of
            clients actually comes from the configured CLIENT_NAMES.
        iter: number of federated iterations to run.
    """

    def __init__(self, num_clients, iter):
        self.iter = iter
        self.clients = {}
        self.server = None

        # Train / Test dataset
        trainset, testset = dataset()

        X_train, y_train = trainset.data, trainset.targets
        X_test, y_test = testset.data, testset.targets

        X_train, X_test = transforms(X_train, X_test)

        # Min-Max scale. The scaler is fitted on the training data only and
        # then applied to the test data, so both splits share one scaling
        # (test values may fall slightly outside [0, 1]).
        scaler = MinMaxScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        # Data Partition
        len_per_iter = LEN_PER_ITER
        client_names = CLIENT_NAMES
        local_trainset = partition(X_train, y_train, client_names, iter, len_per_iter)

        # Evaluator & Security (one instance of each, shared by all clients)
        evaluator = Evaluator(X_test, y_test)
        security = Security()
        security.set_security()

        # Clients
        self.clients = {
            client_name: Client(client_name, local_trainset[client_name], evaluator, security)
            for client_name in client_names
        }

        # Server
        self.server = Server(name='server')

        # Connect to Network
        self.network = Network(clients=self.clients, server=self.server)
        for client in self.clients.values():
            client.connect(self.network)
        self.server.connect(self.network)

    def print_settings(self):
        """Print the configuration used for this simulation."""
        print('----------Simulation settings----------')
        print('Clients: {}'.format(NUM_CLIENTS))
        print('Servers: {}'.format(NUM_SERVER))
        print('Total iteration: {}'.format(NUM_ITER))
        print('Dataset: {} '.format(DATASET_SELECT))
        print('Local trainset size per iteration: {}'.format(LEN_PER_ITER))
        if USE_DROPOUT:
            print('Dropout: ON, ratio {}'.format(DROPOUT_RATIO))
        else:
            print('Dropout: OFF')

    def run(self):
        """Run all federated iterations on the server."""
        self.server.update(iter=self.iter)

    def print_results(self):
        """Print the averaged results and, when PRINT_INDIVIDUAL_RESULT is
        on, the per-client results."""
        self.server.final_results()
        if PRINT_INDIVIDUAL_RESULT:
            self.server.individual_results()


In [10]:
# Silence warnings for the rest of the session.
# NOTE(review): the intent is to hide convergence warnings, but this filter
# has no category restriction, so every warning (deprecations included) is
# suppressed - consider narrowing it.
from warnings import filterwarnings
filterwarnings(action='ignore')

In [11]:


# Stand-alone preparation of the test split for the sanity check below.
trainset, testset = dataset()

# Flatten the samples; the labels are used as they are.
X_train, X_test = transforms(trainset.data, testset.data)
y_train, y_test = trainset.targets, testset.targets

# Min-Max scale the test features (the scaler is fitted on the test split)
scaler = MinMaxScaler()
X_test = scaler.fit_transform(X_test)

# Evaluator & Security
# evaluator = Evaluator(X_test, y_test)

In [12]:
# Sanity check: fit a plain LogisticRegression on the first 300 test samples.
test_model = LogisticRegression()
test_model.fit(X_test[:300], y_test[:300])

LogisticRegression()

In [13]:
# Accuracy on the very same 300 samples the model was fitted on, i.e. a
# training accuracy rather than a held-out score.
test_acc = test_model.score(X_test[:300], y_test[:300])
print(test_acc)

1.0


In [None]:
# Build the simulation from the configured sizes, then report the settings,
# run every federated iteration and print the results.
simulator = Simulator(iter=NUM_ITER, num_clients=NUM_CLIENTS)

simulator.print_settings()
simulator.run()
simulator.print_results()