# **Federated Learning**

#### **Workshop Objective:**

This workshop trains a federated learning model across 100 clients and uses differential privacy to protect each client's training data.

**Federated Learning** allows multiple devices (clients) to train a shared machine learning model while keeping their individual data decentralized. Instead of sending raw data to a central server, clients train the model locally and only share updates (model parameters) with the server.

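**Differential Privacy** (DP) bounds how much any single training example can influence what a client releases. In this workshop each client trains with DP-SGD through Opacus: per-example gradients are clipped to a maximum norm and Gaussian noise is added before each optimizer step. The resulting guarantee is quantified by a privacy budget (ε, δ). Smaller ε means stronger privacy, usually at some cost in accuracy.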
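The core DP-SGD update can be sketched in plain NumPy. This is a simplified illustration, not Opacus' implementation. `dp_sgd_step` is a hypothetical helper that takes precomputed per-sample gradients for a flat parameter vector. It divides by the actual batch size, whereas Opacus, as we understand it, normalizes by the expected batch size under Poisson sampling.

```python
import numpy as np

def dp_sgd_step(params, per_sample_grads, lr, max_grad_norm, noise_multiplier, rng):
    """One DP-SGD update (sketch): clip each per-sample gradient, sum,
    add Gaussian noise scaled to the clipping bound, average, then step."""
    # Clip every example's gradient to L2 norm <= max_grad_norm
    norms = np.linalg.norm(per_sample_grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, max_grad_norm / (norms + 1e-12))
    clipped = per_sample_grads * scale
    # Sum the clipped gradients and add N(0, (noise_multiplier * max_grad_norm)^2) noise
    noisy_sum = clipped.sum(axis=0) + rng.normal(
        0.0, noise_multiplier * max_grad_norm, size=params.shape)
    # Average over the batch and take a plain SGD step
    return params - lr * noisy_sum / len(per_sample_grads)

rng = np.random.default_rng(0)
params = np.zeros(3)
grads = np.array([[3.0, 4.0, 0.0],   # norm 5.0 -> clipped down to norm 2.0
                  [0.3, 0.4, 0.0]])  # norm 0.5 -> left unchanged
new_params = dp_sgd_step(params, grads, lr=0.1, max_grad_norm=2.0,
                         noise_multiplier=1.0, rng=rng)
print(new_params)
```

Clipping bounds each example's contribution to the sum. That bound is what lets the Gaussian noise, scaled by the same `max_grad_norm`, mask any single example.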

## Installing Required Libraries:

* Install the core libraries: `torch` and `torchvision` for model development and data loading, `opacus` for differential privacy, plus `torchmetrics` and `matplotlib` (for example, `pip install torch torchvision torchmetrics opacus matplotlib`).
* `Opacus` is a PyTorch library for training with differential privacy: it computes per-sample gradients, clips them, adds calibrated noise, and tracks the privacy budget spent.
* The cell below imports these libraries along with the standard-library helpers used for data partitioning and file handling.

In [50]:
import copy
import os
import random
from collections import OrderedDict, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchmetrics
from torch import nn
from torch.nn.functional import relu, softmax

import torchvision.transforms as transforms
from torchvision.datasets import MNIST, FashionMNIST

import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager

## Defining Parameters and Variables:

Set up all the necessary parameters for loading, training, and managing the models. These include learning rates, batch sizes, privacy budgets, and other configurations for federated learning and differential privacy.

In [46]:
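device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Data and partitioning
DATA_NAME = 'mnist'           # 'mnist' or 'fashionmnist'
root = './data'               # download/cache directory
NUM_CLIENTS = 100
BATCH_SIZE = 64
NUM_CLASES_PER_CLIENT = 10    # classes assigned to each client (10 = every class)
NUM_CLASSES = 10

# Local training
LEARNING_RATE_DIS = 2e-1      # SGD learning rate used by each client
EPOCHS = 1                    # local epochs per round
ROUNDS = 80                   # federated rounds
sample_rate = 1               # fraction of clients sampled per round (also scales privacy accounting)
MODE = "LDP"                  # local differential privacy: each client runs DP-SGD on its own data

# Privacy budget enforced by Opacus on every client
target_epsilon = 8
target_delta = 1e-3
mp_bs = 64                    # max physical batch size for BatchMemoryManager

user_param = {
    'disc_lr': LEARNING_RATE_DIS,
    'epochs': EPOCHS,
    'rounds': ROUNDS,
    'target_epsilon': target_epsilon,
    'target_delta': target_delta,
    'sr': sample_rate,
    'mp_bs': mp_bs,
}

server_param = {}             # use the LDPServer defaults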
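The following rough worked example shows what these settings mean for privacy accounting. It assumes about 600 training samples per client: 60,000 MNIST images split over 100 clients, with the exact count slightly lower after per-class rounding. It also assumes that Opacus, as in its 1.x releases, sets the sampling rate to `1 / len(data_loader)`. The variable names mirror the configuration cell; `samples_per_client` is an assumption.

```python
import math

# Values mirroring the configuration cell above
EPOCHS, ROUNDS, sample_rate = 1, 80, 1
BATCH_SIZE = 64
samples_per_client = 600      # assumption: ~60,000 MNIST training images / 100 clients

# Epochs the accountant must cover over the whole run (passed as `epochs` to Opacus)
accounting_epochs = EPOCHS * ROUNDS * sample_rate
# Batches per local epoch and the resulting Poisson sampling rate q = 1 / len(loader)
batches_per_epoch = math.ceil(samples_per_client / BATCH_SIZE)
q = 1 / batches_per_epoch
# Noisy optimizer steps over which the (epsilon, delta) budget is spread
total_steps = accounting_epochs * batches_per_epoch
print(accounting_epochs, batches_per_epoch, q, total_steps)   # 80 10 0.1 800
```

Raising `ROUNDS` or `EPOCHS` increases the number of noisy steps. For the same `target_epsilon`, Opacus then has to choose a larger noise multiplier.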

## Loading the Dataset:

* Load the widely used `MNIST` dataset: 60,000 training and 10,000 test grayscale images of handwritten digits (0-9), each 28x28 pixels and labeled with one of 10 classes.
* Implement helper functions that split both the training and test sets across the 100 clients. With `NUM_CLASES_PER_CLIENT = 10` and uniform proportions, every client receives roughly 1% of each class, giving a balanced (IID) data distribution.
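The balanced split can be sketched independently of PyTorch. `balanced_iid_split` is a hypothetical, simplified equivalent of `gen_classes_per_node` plus `gen_data_split` for the case where every client gets every class. Each client receives `n_c // num_clients` shuffled samples of class `c`. The notebook computes `int(n_c * 0.01)`, which matches up to floating-point rounding, and it also drops one sample when a client's total is odd. The labels here are random stand-ins, not MNIST.

```python
import numpy as np

def balanced_iid_split(labels, num_clients, rng):
    """Give every client n_c // num_clients samples of each class c (sketch)."""
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])   # shuffled indexes of class c
        share = len(idx) // num_clients                   # equal slice per client
        for k in range(num_clients):
            client_idx[k].extend(idx[k * share:(k + 1) * share].tolist())
    return client_idx

rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=6000)   # hypothetical stand-in for MNIST labels
parts = balanced_iid_split(labels, num_clients=100, rng=rng)
print(len(parts), [len(p) for p in parts[:3]])
```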

In [3]:
def get_datasets(data_name, dataroot, preprocess = None): 
    if data_name == 'mnist':
        normalization = transforms.Normalize((0.5,), (0.5,))
        transform = transforms.Compose([transforms.ToTensor(), normalization])
        data_obj = MNIST
    elif data_name == 'fashionmnist':
        normalization = transforms.Normalize((0.5,), (0.5,))
        transform = transforms.Compose([transforms.ToTensor(),  normalization])
        data_obj = FashionMNIST
    else:
        raise ValueError("choose data_name from ['mnist', 'fashionmnist']")


    train_set = data_obj(dataroot, train=True, transform=transform, download=True)
    test_set = data_obj(dataroot, train=False, transform=transform)
    return train_set, test_set

In [4]:
def get_num_classes_samples(dataset):
    """
    extracts label information from a dataset
    :param dataset: pytorch dataset object (or a Subset of one)
    :return: number of classes, number of samples per class, array of labels
    """
    # ---------------#
    # Extract labels #
    # ---------------#
    if isinstance(dataset, torch.utils.data.Subset):
        if isinstance(dataset.dataset.targets, list):
            data_labels_list = np.array(dataset.dataset.targets)[dataset.indices]
        else:
            data_labels_list = dataset.dataset.targets[dataset.indices]
    else:
        if isinstance(dataset.targets, list):
            data_labels_list = np.array(dataset.targets)
        else:
            data_labels_list = dataset.targets
    classes, num_samples = np.unique(data_labels_list, return_counts=True)
    num_classes = len(classes)
    return num_classes, num_samples, data_labels_list

In [5]:
def gen_classes_per_node(dataset, num_users, classes_per_user=2, high_prob=0.6, low_prob=0.4):
    """
    creates the class assignment and class proportions of each client
    :param dataset: pytorch dataset object
    :param num_users: number of clients
    :param classes_per_user: number of classes assigned to each client
    :param high_prob: unused (kept for signature compatibility)
    :param low_prob: unused (kept for signature compatibility)
    :return: dict with 'class' (class ids per client) and 'prob' (share of each class's samples per client)
    """
    num_classes, num_samples, _ = get_num_classes_samples(dataset)

    # ---------------------------------------------------------#
    # Each class is shared equally by count_per_class clients  #
    # ---------------------------------------------------------#
    assert (classes_per_user * num_users) % num_classes == 0, "equal classes appearance is needed"
    count_per_class = (classes_per_user * num_users) // num_classes
    class_dict = {}
    for i in range(num_classes):
        # uniform proportions: each client holding class i gets 1/count_per_class of it
        probs = np.array([1] * count_per_class)
        probs_norm = (probs / probs.sum()).tolist()
        class_dict[i] = {'count': count_per_class, 'prob': probs_norm}
    # ----------------------------------------------------------#
    # Assign classes to clients, always choosing among the       #
    # classes with the most free slots (no repeats per client)   #
    # ----------------------------------------------------------#
    class_partitions = defaultdict(list)
    for _user in range(num_users):
        c = []
        for _slot in range(classes_per_user):
            class_counts = [class_dict[k]['count'] for k in range(num_classes)]
            max_class_counts = np.where(np.array(class_counts) == max(class_counts))[0]
            max_class_counts = np.setdiff1d(max_class_counts, np.array(c))
            c.append(np.random.choice(max_class_counts))
            class_dict[c[-1]]['count'] -= 1
        class_partitions['class'].append(c)
        class_partitions['prob'].append([class_dict[k]['prob'].pop() for k in c])
    return class_partitions

In [7]:
def gen_data_split(dataset, num_users, class_partitions):
    """
    divide data indexes for each client based on class_partition
    :param dataset: pytorch dataset object (train/val/test)
    :param num_users: number of clients
    :param class_partitions: proportion of classes per client
    :return: list with one list of sample indexes per client
    """
    num_classes, num_samples, data_labels_list = get_num_classes_samples(dataset)

    # -------------------------- #
    # Create class index mapping #
    # -------------------------- #
    data_class_idx = {i: np.where(data_labels_list == i)[0] for i in range(num_classes)}

    # --------- #
    # Shuffling #
    # --------- #
    for data_idx in data_class_idx.values():
        random.shuffle(data_idx)

    # ------------------------------ #
    # Assigning samples to each user #
    # ------------------------------ #
    user_data_idx = [[] for i in range(num_users)]
    for usr_i in range(num_users):
        for c, p in zip(class_partitions['class'][usr_i], class_partitions['prob'][usr_i]):
            end_idx = int(num_samples[c] * p)
            user_data_idx[usr_i].extend(data_class_idx[c][:end_idx])
            data_class_idx[c] = data_class_idx[c][end_idx:]
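        # drop the last sample if the client's count is odd, so every client holds an even number of samples
        if len(user_data_idx[usr_i]) % 2 == 1:
            user_data_idx[usr_i] = user_data_idx[usr_i][:-1]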

    return user_data_idx

In [8]:
def gen_random_loaders(data_name, data_path, num_users, bz, num_classes_per_user, num_classes, preprocess=None):
    """
    generates train/test loaders for each client
    :param data_name: name of dataset, choose from ['mnist', 'fashionmnist']
    :param data_path: root path for data dir
    :param num_users: number of clients
    :param bz: batch size
    :param num_classes_per_user: number of classes assigned to each client
    :param num_classes: total number of classes in the dataset
    :param preprocess: passed through to get_datasets (currently unused there)
    :return: [train_loaders, test_loaders], each a list with one pytorch DataLoader per client
    """
    loader_params = {"batch_size": bz, "shuffle": False, "pin_memory": True, "num_workers": 0}
    dataloaders = []
    datasets = get_datasets(data_name, data_path, preprocess=preprocess)
    # print(datasets)
    cls_partitions = None
    distribution = np.zeros((num_users, num_classes))
    for i, d in enumerate(datasets):
        if i == 0:
            cls_partitions = gen_classes_per_node(d, num_users, num_classes_per_user)
            # print(cls_partitions)
            for index in range(num_users):
                distribution[index][cls_partitions['class'][index]] = cls_partitions['prob'][index]

            loader_params['shuffle'] = True
        usr_subset_idx = gen_data_split(d, num_users, cls_partitions)

        subsets = list(map(lambda x: torch.utils.data.Subset(d, x), usr_subset_idx))
        dataloaders.append(list(map(lambda x: torch.utils.data.DataLoader(x, **loader_params), subsets)))

    return dataloaders

### Saving Training and Test Data:

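* Partition the data and save each client's training and test split to disk as a single `.pt` file holding an `images` tensor and a `labels` tensor, so a client's data can be reloaded later without re-running the partitioning.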

In [9]:
# Build one training and one test DataLoader per client
train_dataloaders, test_dataloaders  = gen_random_loaders(DATA_NAME, root, NUM_CLIENTS, BATCH_SIZE, NUM_CLASES_PER_CLIENT, NUM_CLASSES)

In [45]:
def saveClientData(NUM_CLIENTS, dataset, folder_name, train = True):
    """Save each client's data as one .pt file with 'images' and 'labels' tensors."""
    split = 'train' if train else 'test'
    data_directory = f'./{folder_name}/client_{split}_data'
    os.makedirs(data_directory, exist_ok=True)

    for i in range(NUM_CLIENTS):
        user_batch_data = []
        user_batch_labels = []
        for img, lab in dataset[i]:          # iterate over the client's DataLoader
            user_batch_data.append(img)
            user_batch_labels.append(lab)

        # Concatenate the batches: images -> (N, 1, 28, 28), labels -> (N,)
        user_data_tensor = torch.cat(user_batch_data, dim=0)
        user_labels_tensor = torch.cat(user_batch_labels, dim=0)

        torch.save({'images': user_data_tensor, 'labels': user_labels_tensor},
                   f'{data_directory}/client_{split}_{i:02}.pt')


def saveWeights(users, folder_name):
    """Save each client's model state_dict to client_model_weights/."""
    weight_directory = f'./{folder_name}/client_model_weights'
    os.makedirs(weight_directory, exist_ok=True)
    for i, user in enumerate(users):
        torch.save(user.get_model_state_dict(), f"{weight_directory}/weight_user{i:02}.pth")


def saveModels(users, folder_name):
    """Save each client's model state_dict to client_model/ (same contents as saveWeights)."""
    model_directory = f'./{folder_name}/client_model'
    os.makedirs(model_directory, exist_ok=True)
    for i, user in enumerate(users):
        torch.save(user.model.state_dict(), f"{model_directory}/model_user{i:02}.pth")


folder_name = 'FL_LDP_data'
saveClientData(NUM_CLIENTS, train_dataloaders, folder_name, train=True)
saveClientData(NUM_CLIENTS, test_dataloaders, folder_name, train=False)

## Building the Main Model:

* Implement the neural network shared by the clients and the server: a fully connected network (784 → 600 → 100 → 10) with ReLU activations and no bias terms. Each client initializes its own copy and trains it locally.

In [12]:
# MODELS
class mnist_fully_connected(nn.Module):
    """Bias-free 784-600-100-10 MLP; forward returns (logits, probs) by default."""
    def __init__(self, num_classes):
        super().__init__()
        self.hidden1 = 600
        self.hidden2 = 100
        self.fc1 = nn.Linear(28 * 28, self.hidden1, bias=False)
        self.fc2 = nn.Linear(self.hidden1, self.hidden2, bias=False)
        self.fc3 = nn.Linear(self.hidden2, num_classes, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x, return_probs=True):
        x = x.view(-1, 28 * 28)          # flatten 1x28x28 images to 784-dim vectors
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        logits = self.fc3(x)
        if return_probs:
            return logits, softmax(logits, dim=1)
        return logits
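The following NumPy sketch makes the layer shapes concrete by running the same forward pass with randomly initialized weights. The helper `mlp_forward` and the initialization scale are illustrative assumptions. The sketch also counts the parameters, 784*600 + 600*100 + 100*10 = 531,400, all of them weights because the layers have no bias.

```python
import numpy as np

def mlp_forward(x, w1, w2, w3):
    """Forward pass of the bias-free 784-600-100-10 MLP (NumPy sketch)."""
    x = x.reshape(-1, 28 * 28)              # flatten, like x.view(-1, 784)
    h1 = np.maximum(x @ w1.T, 0.0)          # fc1 + ReLU
    h2 = np.maximum(h1 @ w2.T, 0.0)         # fc2 + ReLU
    logits = h2 @ w3.T                      # fc3
    # Numerically stable softmax over the class dimension
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return logits, probs

rng = np.random.default_rng(0)
# Weight shapes follow nn.Linear's (out_features, in_features) convention
w1 = rng.normal(0, 0.05, (600, 784))
w2 = rng.normal(0, 0.05, (100, 600))
w3 = rng.normal(0, 0.05, (10, 100))
logits, probs = mlp_forward(rng.normal(size=(4, 1, 28, 28)), w1, w2, w3)
n_params = w1.size + w2.size + w3.size
print(logits.shape, n_params)   # (4, 10) 531400
```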


## Building Client and Server-Side Classes:

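* Define the client (`LDPUser`) and server (`LDPServer`) classes. Each client holds a local copy of the model, which Opacus wraps via `make_private_with_epsilon`: the noise multiplier is calibrated so that training for `epochs * rounds * sr` epochs spends at most `target_epsilon` at `target_delta`. In `LDP` mode the server object only supplies the initial weights; aggregation is the plain average computed by `agg_weights`.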

In [25]:
class LDPUser:
    def __init__(self, index, device, model, n_classes, input_shape, train_dataloader, epochs, rounds,
                 target_epsilon, target_delta, sr, max_norm=2.0, disc_lr=5e-1, mp_bs=3):
        self.index = index
        self.rounds = rounds
        self.target_epsilon = target_epsilon
        self.epsilon = 0
        self.delta = target_delta
        # Build the model and replace any layers Opacus does not support (e.g. BatchNorm)
        self.model = model(num_classes=n_classes)
        self.model = ModuleValidator.fix(self.model)
        self.train_dataloader = train_dataloader
        self.sr = sr
        self.mp_bs = mp_bs
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.disc_lr = disc_lr
        # torchmetrics >= 0.11 requires the task type and the number of classes
        self.acc_metric = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes).to(device)
        self.device = device
        self.max_norm = max_norm
        self.epochs = epochs
        self.optim = torch.optim.SGD(self.model.parameters(), self.disc_lr)
        self.make_local_private()

    def make_local_private(self):
        # Wrap model, optimizer and data loader for DP-SGD; Opacus picks the noise
        # multiplier so that the whole run (epochs * rounds * sr epochs) meets (epsilon, delta)
        self.privacy_engine = opacus.PrivacyEngine()
        self.model, self.optim, self.train_dataloader = self.privacy_engine.make_private_with_epsilon(
            module=self.model, optimizer=self.optim, data_loader=self.train_dataloader,
            epochs=self.epochs * self.rounds * self.sr,
            target_epsilon=self.target_epsilon, target_delta=self.delta,
            max_grad_norm=self.max_norm)

    def train(self):
        self.model = self.model.to(self.device)
        self.model.train()
        for _ in range(self.epochs):
            # Split large logical batches into physical batches of at most mp_bs samples
            with BatchMemoryManager(data_loader=self.train_dataloader, max_physical_batch_size=self.mp_bs,
                                    optimizer=self.optim) as batch_loader:
                for images, labels in batch_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    self.optim.zero_grad()
                    logits, preds = self.model(images, return_probs=True)
                    loss = self.loss_fn(logits, labels)
                    loss.backward()
                    # DP step: clip per-sample grads and add noise once the logical batch is complete
                    self.optim.step()
                    self.acc_metric(preds, labels)
        # Privacy spent so far by this client
        self.epsilon = self.privacy_engine.get_epsilon(self.delta)
        print(f"Client: {self.index} ACC: {self.acc_metric.compute()}, epsilon: {self.epsilon}")
        self.acc_metric.reset()
        self.model.to('cpu')

    def evaluate(self, dataloader):
        """Return (number of correct predictions, number of samples) on dataloader."""
        self.model.to(self.device)
        self.model.eval()
        testing_corrects = 0
        testing_sum = 0
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                _, preds = self.model(images, return_probs=True)
                testing_corrects += torch.sum(torch.argmax(preds, dim=1) == labels).item()
                testing_sum += len(labels)
        self.model.to('cpu')
        return testing_corrects, testing_sum

    def get_model_state_dict(self):
        return self.model.state_dict()

    def set_model_state_dict(self, weights):
        # Copy incoming weights into the model in place, skipping batch-norm ('bn') entries
        for key, value in self.model.state_dict().items():
            if 'bn' not in key:
                value.data.copy_(weights[key])
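`LDPUser.train` wraps its loop in Opacus' `BatchMemoryManager`, which splits each logical (Poisson-sampled) batch into physical chunks of at most `mp_bs` samples to bound memory use. Clipping is per sample and the noisy update happens once per logical batch, so chunking should not change the result. The NumPy sketch below shows only the accumulation idea: summing per-sample gradients chunk by chunk gives the same total as one full-batch sum. `accumulate_in_chunks` is a hypothetical helper, not part of Opacus.

```python
import numpy as np

def accumulate_in_chunks(per_sample_grads, max_physical_batch_size):
    """Sum per-sample gradients chunk by chunk, as virtual batching does (sketch)."""
    total = np.zeros(per_sample_grads.shape[1])
    n_chunks = 0
    for start in range(0, len(per_sample_grads), max_physical_batch_size):
        chunk = per_sample_grads[start:start + max_physical_batch_size]
        total += chunk.sum(axis=0)      # accumulate only; no optimizer step yet
        n_chunks += 1
    return total, n_chunks              # one (noisy) step would then use `total`

rng = np.random.default_rng(0)
grads = rng.normal(size=(64, 5))        # hypothetical per-sample gradients
total, n_chunks = accumulate_in_chunks(grads, max_physical_batch_size=3)
print(n_chunks, np.allclose(total, grads.sum(axis=0)))   # 22 True
```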

In [14]:
class LDPServer:
    def __init__(self, device, model, n_classes, input_shape, noise_multiplier=1, sample_clients=10, disc_lr=1):
        self.model = model(num_classes=n_classes)
        self.model = ModuleValidator.fix(self.model)
        # Wrap the model like the clients' models so state_dict keys share the "_module." prefix
        self.privacy_engine = opacus.PrivacyEngine()
        self.model = self.privacy_engine._prepare_model(self.model)
        self.device = device
        # noise_multiplier, sample_clients and disc_lr are not used in LDP mode
        self.noise_multiplier = noise_multiplier
        self.sample_clients = sample_clients
        self.disc_lr = disc_lr

    def get_model_state_dict(self):
        return self.model.state_dict()

# Average the client models' weights key by key (unweighted FedAvg).
def agg_weights(weights):
    with torch.no_grad():
        weights_avg = copy.deepcopy(weights[0])
        for k in weights_avg.keys():
            for i in range(1, len(weights)):
                weights_avg[k] += weights[i][k]
            weights_avg[k] = torch.div(weights_avg[k], len(weights))
    return weights_avg

def evaluate_global(users, test_dataloaders, users_index):
    """Pool correct predictions over the given clients' test sets and return the accuracy."""
    testing_corrects = 0
    testing_sum = 0
    for index in users_index:
        corrects, num = users[index].evaluate(test_dataloaders[index])
        testing_corrects += corrects
        testing_sum += num
    print(f"Acc: {testing_corrects / testing_sum}")
    return testing_corrects / testing_sum

## Passing Weights from Server to Clients:

* Synchronize the models by copying the server's initial weights to every client, so all clients start training from the same initialization.

In [None]:
user_obj = LDPUser
server_obj = LDPServer
MODEL = mnist_fully_connected
server = server_obj(device, MODEL, NUM_CLASSES, None, **server_param)
users = [user_obj(i, device, MODEL, NUM_CLASSES, None, train_dataloaders[i], **user_param) for i in range(NUM_CLIENTS)]

# Broadcast the server's initial weights to every client
for i in range(NUM_CLIENTS):
    users[i].set_model_state_dict(server.get_model_state_dict())
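Broadcasting the initial weights is a key-by-key, in-place copy. Both the clients' models and the server model are wrapped by Opacus, so their `state_dict` keys carry the `_module.` prefix that `strip_prefix` removes later. If only one side were wrapped, the key lookups would fail. The sketch below mirrors `set_model_state_dict`, with NumPy arrays standing in for tensors. `broadcast_weights` and its explicit `KeyError` message are illustrative additions, not part of the notebook.

```python
import numpy as np

def broadcast_weights(client_state, global_state, skip_substring="bn"):
    """Copy global values into a client's state dict in place, key by key (sketch)."""
    for key in client_state:
        if skip_substring in key:
            continue                                   # keep local batch-norm entries
        if key not in global_state:
            raise KeyError(f"missing key {key!r}; check for a '_module.' prefix mismatch")
        client_state[key][...] = global_state[key]     # in-place copy, like tensor.copy_()
    return client_state

global_state = {"_module.fc1.weight": np.ones((2, 2))}
client = {"_module.fc1.weight": np.zeros((2, 2))}
broadcast_weights(client, global_state)
print(client["_module.fc1.weight"])
```

Copying in place keeps each client's own parameter storage, so later local training does not modify the server's or another client's weights.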

## Training and Aggregation Loop:

* In each round, train the sampled clients' models locally, then collect their updated parameters.
* Average these parameters (`agg_weights` computes an unweighted mean, which is reasonable here because every client holds roughly the same number of samples) and send the resulting global model back to all clients to continue the loop.
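Plain and sample-weighted federated averaging (FedAvg) differ only in their coefficients. The sketch below uses a hypothetical `fed_avg` on NumPy arrays instead of tensors. With no sample counts it reproduces `agg_weights`; given counts, it weights each client by its share of the data.

```python
import numpy as np

def fed_avg(states, num_samples=None):
    """Average client state dicts key by key; weight by sample count if given (sketch)."""
    if num_samples is None:
        weights = np.full(len(states), 1.0 / len(states))   # unweighted, like agg_weights
    else:
        weights = np.asarray(num_samples, dtype=float) / np.sum(num_samples)
    return {k: sum(w * s[k] for w, s in zip(weights, states)) for k in states[0]}

states = [{"w": np.array([1.0, 2.0])}, {"w": np.array([3.0, 6.0])}]
print(fed_avg(states)["w"])                 # [2. 4.]
print(fed_avg(states, [100, 300])["w"])     # [2.5 5. ]
```

With equally sized clients, as in this balanced MNIST split, both variants give essentially the same result.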

In [None]:
best_acc = 0
for rnd in range(ROUNDS):
    # Sample a fraction `sample_rate` of the clients and train each locally with DP-SGD
    random_index = np.random.choice(NUM_CLIENTS, int(sample_rate * NUM_CLIENTS), replace=False)
    for index in random_index:
        users[index].train()

    # Save the client models and weights (before aggregation) in the last round
    if rnd == ROUNDS - 1:
        saveModels(users, folder_name)
        saveWeights(users, folder_name)

    if MODE == "LDP":
        # Average the sampled clients' weights and broadcast the result to every client
        weights_agg = agg_weights([users[index].get_model_state_dict() for index in random_index])
        for i in range(NUM_CLIENTS):
            users[i].set_model_state_dict(weights_agg)
    else:
        # Note: LDPServer defines no agg_updates method, so only MODE == "LDP" runs as written
        server.agg_updates([users[index].get_model_state_dict() for index in random_index])
        for i in range(NUM_CLIENTS):
            users[i].set_model_state_dict(server.get_model_state_dict())

    print(f"Round: {rnd + 1}")
    acc = evaluate_global(users, test_dataloaders, range(NUM_CLIENTS))
    best_acc = max(best_acc, acc)
    if MODE == "LDP":
        # Stop early if any client has exceeded the privacy budget
        eps = max(user.epsilon for user in users)
        print(f"Epsilon: {eps}")
        if eps > target_epsilon:
            saveModels(users, folder_name)
            saveWeights(users, folder_name)
            break

print(f"Best accuracy: {best_acc}")
print('Federated Learning Client Training Finished')

## Visualizing the Model and Data:

* After training, reload the saved data and weights of client 0, check that client's accuracy on its own test split, and plot a few of its training and test images with their labels.

In [None]:
training_data = torch.load('./FL_LDP_data/client_train_data/client_train_00.pt')
testing_data = torch.load('./FL_LDP_data/client_test_data/client_test_00.pt')

x_train = training_data['images']
y_train = training_data['labels']
x_test = testing_data['images']
y_test = testing_data['labels']
# Output size of the saved model (every client model has NUM_CLASSES outputs)
num_classes = NUM_CLASSES

print(f"Train data shape: {x_train.shape}, training labels shape: {y_train.shape}")
print(f"Test data shape: {x_test.shape}, test labels shape: {y_test.shape}")

In [52]:
def strip_prefix(state_dict, prefix="_module."):
    """
    Strip a prefix from the state_dict keys.
    Args:
        state_dict (dict): The state_dict with the potentially prefixed keys.
        prefix (str): The prefix to remove.
    Returns:
        dict: The state_dict with the prefix removed from the keys.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith(prefix):
            new_state_dict[k[len(prefix):]] = v
        else:
            new_state_dict[k] = v
    return new_state_dict

In [None]:
model_path = './FL_LDP_data/client_model_weights/weight_user00.pth'  # Path to your saved model weights
model = mnist_fully_connected(num_classes)

# Load the state_dict (saved from the Opacus-wrapped model) onto the CPU
state_dict = torch.load(model_path, map_location='cpu')
# Strip the "_module." prefix added by the Opacus wrapper
state_dict = strip_prefix(state_dict, prefix="_module.")
# Load the modified state_dict into the model
model.load_state_dict(state_dict)
# Set the model to evaluation mode
model.eval()

# Check the model's accuracy on the client's test data
with torch.no_grad():
    test_logits, _ = model(x_test)

accuracy = torch.sum(torch.argmax(test_logits, dim=1) == y_test) / len(y_test)
print('The accuracy on the test data is:', accuracy.item() * 100)

#### Visualize the train data and labels

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(15, 7))
ax = ax.flatten()

for i, (image, label) in enumerate(zip(x_train[:10], y_train[:10])):
    img = image.squeeze(0)          # (1, 28, 28) -> (28, 28)
    ax[i].imshow(img, cmap="Greys")
    ax[i].set_title(f"Label: {label.item()}", fontsize=12)
    ax[i].set_xticks([])
    ax[i].set_yticks([])

#### Visualize the test data and labels

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(15, 7))
ax = ax.flatten()

for i, (image, label) in enumerate(zip(x_test[:10], y_test[:10])):
    img = image.squeeze(0)          # (1, 28, 28) -> (28, 28)
    ax[i].imshow(img, cmap="Greys")
    ax[i].set_title(f"Label: {label.item()}", fontsize=12)
    ax[i].set_xticks([])
    ax[i].set_yticks([])

### Summary:

This workshop trains a federated learning model across 100 clients with local differential privacy. Participants will install PyTorch and Opacus, split MNIST evenly across the clients, build the client and server classes, and train each client with DP-SGD under a target privacy budget (ε = 8, δ = 1e-3), averaging the client weights after every round. Finally, the notebook reloads one client's saved data and weights to check its test accuracy and visualize sample images.