<h1>The purpose of this notebook is to mimic the FedAvg Algorithm used in <i>Communication-Efficient Learning of Deep Networks from Decentralized Data</i>, produce similar results, and gain coding experience in Federated Learning concepts</h1> For HW, compare IID vs non IDD, and implement round robin style scheduling of clients compared to the "random" scheduling. If no change is seen then change the parameters to show the difference

In [1]:
# Import Global Dependencies
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch

# Import Helper Libaries
import matplotlib.pyplot as plt
from torchinfo import summary
import numpy as np
import random
import copy

<h2>Data Preprocessing</h2>

<h3>Decentralize Dataset Function</h3>

In [2]:
class DecentralizeDataset(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)
        
    def __len__(self):
        return len(self.idxs)
    
    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label
        
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

training_dataset = datasets.MNIST('../data/mnist/', train=True, download=True, transform=normalize)
validation_dataset = datasets.MNIST('../data/mnist/', train=False, download=True, transform=normalize)

<h3>IID Data</h3>
The data is shuffled and then divided up across 100 clients each receiving 600 examples

In [3]:
def IID(dataset, num_clients):
    dict_clients = {}
    # list of all indices in the dataset
    idxs = list(range(len(dataset)))
    # 600 data samples per client
    data_per_client = len(dataset) // num_clients 
    
    for i in range(num_clients):
        # Randomly select a subset of indices for the current client
        client_idxs = np.random.choice(idxs, size=data_per_client, replace=False)
        # Add the selected indices to the dictionary for the current client
        dict_clients[i] = set(client_idxs)
        # Remove the selected indices from the list of available indices
        idxs = list(set(idxs) - set(client_idxs)) 
    return dict_clients

<h3>Non-IID Data</h3>
The data is sorted by digit label, divided up into 200 'shards' of 300 examples, and then each client receieves 2 'shards'

In [4]:
def nonIID(dataset, num_clients):
    dict_clients = {}
    # List of all indices in the dataset
    idxs = list(range(len(dataset)))
    # 600 data samples per client
    data_per_client = len(dataset) // num_clients
    # Sort the dataset by label
    sorted_idxs = sorted(idxs, key=lambda i: dataset.targets[i])
    # Divide the sorted dataset into 200 shards of size 300
    shards = [sorted_idxs[i:i+300] for i in range(0, len(sorted_idxs), 300)]
    # Assign 2 shards to each client
    shards_per_client = 2
    for i in range(num_clients):
        # Randomly select 2 shards for the current client
        client_shards = np.random.choice(len(shards), size=shards_per_client, replace=False)
        client_idxs = []
        # Gather indices from selected shards
        for shard_idx in client_shards:
            client_idxs.extend(shards[shard_idx])
        # Add the selected indices to the dictionary for the current client
        dict_clients[i] = set(client_idxs)
    return dict_clients

<h2>CNN Model Declaration</h2>

In [5]:
class CNN(nn.Module):
    def __init__(self, args):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=args.num_dimensions, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, args.num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        x = torch.softmax(x, dim=1) 
        return x

<h2>Local Model Training</h2>

In [6]:
class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.train_data = DataLoader(DecentralizeDataset(dataset, idxs), batch_size=self.args.local_batchsize, shuffle=True)

    def local_training(self, net):
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr)
        epoch_loss = []
        
        for epoch in range(self.args.local_epochs):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.train_data):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
    
                optimizer.zero_grad()
                prediction = net(images)
                loss = self.args.loss_function(prediction, labels)
                loss.backward()
                optimizer.step()
    
                batch_loss.append(loss.item())
                if (self.args.verbose and batch_idx % 10 == 0):
                    print('Local Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(images), len(self.train_data), 100. * batch_idx / len(self.train_data), loss.item()))
    
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), (sum(epoch_loss)/len(epoch_loss))

<h2>Model Evaluation</h2>

In [7]:
def model_evaluation(net, dataset, args):
    net.eval()
                     
    running_loss = 0.0
    correct = 0
    total = 0

    validation_loader = DataLoader(dataset, batch_size=args.global_batchsize)

    predicted_labels = []
    true_labels = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(validation_loader):
            images, labels = images.to(args.device), labels.to(args.device)
            
            prediction = net(images)
            loss = args.loss_function(prediction, labels)
            running_loss += loss.item()
            
            _, predicted = torch.max(prediction, 1)
            correct += (predicted == labels).sum().item()

            predicted_labels.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            
    validation_loss = running_loss / len(validation_loader.dataset)
    validation_accuracy = correct / len(validation_loader.dataset)

    return validation_loss, validation_accuracy

<h2>Federated Averaging Algorithm</h2>

In [8]:
def FedAvg(local_ws, clients):
    avg_w = {}
    for k in local_ws[0].keys():
        sum_w = torch.zeros_like(local_ws[0][k])
        for i in range(1, len(local_ws)):
            sum_w += torch.mul(local_ws[i][k], clients[i])
        avg_w[k] = torch.div(sum_w, sum(clients))
    return avg_w

<h2>Global Model Training</h2>

In [9]:
def training(global_model, training_dataset, validation_dataset, dict_clients, args):
    global_loss, global_acc = [], []

    global_model.train()
    global_w = global_model.state_dict()
    
    for round in range(args.total_rounds):
        local_ws, local_losses = [], []
        m = max(int(args.fraction_clients * args.clients_per_round), 1)
        selected_m = np.random.choice(range(args.clients_per_round), m, replace=False)

        for k in selected_m:
            local = LocalUpdate(args=args, dataset=training_dataset, idxs=dict_clients[k])
            w, loss = local.local_training(net=copy.deepcopy(global_model).to(args.device))
            local_ws.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # Update Global Model
        global_w = FedAvg(local_ws, selected_m)
        global_model.load_state_dict(global_w)

        # Calculate Average Loss
        avg_loss = sum(local_losses) / len(local_losses)
        global_loss.append(avg_loss)

        # Calculate GLobal Model Evaluation Loss & Accuracy 
        global_model.eval()
        _, global_round_acc = model_evaluation(global_model, validation_dataset, args)
        print('Round {:3d}, Average loss {:.4f}, Accuracy {:.4f}'.format(round + 1, avg_loss, global_round_acc))

        global_acc.append(global_round_acc)

    global_model.eval()
    _, final_train_acc = model_evaluation(global_model, training_dataset, args)
    _, final_valid_acc = model_evaluation(global_model, validation_dataset, args)

    print("\nFinal Training accuracy: {:.4f}".format(final_train_acc))
    print("Final Testing accuracy: {:.4f}".format(final_valid_acc))
    
    return global_loss, global_acc

<h2>Hyperparameters for IID</h2>

In [None]:
class FederatedSettings:
    def __init__(self, device, loss_function, num_classes, num_dimensions, lr, global_batchsize, total_rounds, verbose, clients_per_round, fraction_clients, local_batchsize, local_epochs):
        self.device = device
        self.loss_function = loss_function
        self.num_classes = num_classes
        self.num_dimensions = num_dimensions
        
        self.lr = lr
        self.global_batchsize = global_batchsize
        self.total_rounds = total_rounds                    
        self.verbose = verbose
        
        self.clients_per_round = clients_per_round
        self.fraction_clients = fraction_clients    
        self.local_batchsize = local_batchsize
        self.local_epochs = local_epochs

args = FederatedSettings(
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    loss_function = nn.CrossEntropyLoss(),             # Objective Function
    num_dimensions = 1,                                # Input Shape
    num_classes = 10,                                  # Output Classes

    lr = 0.1,                                          # Learning Rate 
    global_batchsize = 128,                            # Global Batch Size
    total_rounds = 1000,                               # Global Epochs or 'Communication rounds'    
    verbose = False,
    
    clients_per_round = 100,                           # Clients participating per round (K)
    fraction_clients = 0.1,                            # Fraction of Clients (C)
    local_batchsize = 10,                              # Local Minibatch size (B)
    local_epochs = 5,                                 # Local Epochs (E)
)

dict_clients_IID = IID(training_dataset, args.clients_per_round)
global_model_IID = CNN(args=args).to(args.device)

global_loss_IID, global_acc_IID = training(global_model_IID, training_dataset, validation_dataset, dict_clients_IID, args)
torch.save(global_model_IID, "FedAvg_IID")

Round   1, Average loss 1.8554, Accuracy 0.9134
Round   2, Average loss 1.5506, Accuracy 0.9501
Round   3, Average loss 1.5164, Accuracy 0.9596
Round   4, Average loss 1.5074, Accuracy 0.9670


<h2>Hyperparameters for Non-IID</h2>

In [None]:
class FederatedSettings:
    def __init__(self, device, loss_function, num_classes, num_dimensions, lr, global_batchsize, total_rounds, verbose, clients_per_round, fraction_clients, local_batchsize, local_epochs):
        self.device = device
        self.loss_function = loss_function
        self.num_classes = num_classes
        self.num_dimensions = num_dimensions
        
        self.lr = lr
        self.global_batchsize = global_batchsize
        self.total_rounds = total_rounds                    
        self.verbose = verbose
        
        self.clients_per_round = clients_per_round
        self.fraction_clients = fraction_clients    
        self.local_batchsize = local_batchsize
        self.local_epochs = local_epochs

args = FederatedSettings(
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    loss_function = nn.CrossEntropyLoss(),             # Objective Function
    num_dimensions = 1,                                # Input Shape
    num_classes = 10,                                  # Output Classes

    lr = 0.1,                                          # Learning Rate 
    global_batchsize = 128,                            # Global Batch Size
    total_rounds = 1000,                               # Global Epochs or 'Communication rounds'    
    verbose = False,
    
    clients_per_round = 100,                           # Clients participating per round (K)
    fraction_clients = 0.1,                            # Fraction of Clients (C)
    local_batchsize = 10,                              # Local Minibatch size (B)
    local_epochs = 5,                                 # Local Epochs (E)
)

dict_clients_nonIID = nonIID(training_dataset, args.clients_per_round)
global_model_nonIID = CNN(args=args).to(args.device)

global_loss_nonIID, global_acc_nonIID = training(global_model_nonIID, training_dataset, validation_dataset, dict_clients_nonIID, args)
torch.save(global_model_nonIID, "FedAvg_nonIID")

<h2>Global Model Complexity</h2>

In [None]:
summary(global_model_IID, input_size=(128, 1, 28, 28), verbose=2)
summary(global_model_nonIID, input_size=(128, 1, 28, 28), verbose=2)

# global_model = torch.load('FedAvg_IID')
epochs_range = range(1, args.total_rounds + 1)

# Plot Global Training Loss
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, global_loss_IID, color='red', linestyle="dashed", label="B = 10  E = 5")
plt.xlabel('Communication Rounds')
plt.ylabel('Train Loss')
plt.legend(loc='upper right')  # Move legend to bottom right
plt.title('MNIST CNN IID')

# Plot Global Validation Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs_range, global_acc_IID, color='red', linestyle="dashed", label="B = 10  E = 5")
plt.xlabel('Communication Rounds')
plt.ylabel('Test Accuracy')
plt.legend(loc='lower right')  # Move legend to bottom right
plt.title('MNIST CNN IID')

plt.tight_layout()
plt.show()

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, global_loss_nonIDD, color='red', linestyle="dashed", label="B = 10  E = 5")
plt.xlabel('Communication Rounds')
plt.ylabel('Train Loss')
plt.legend(loc='upper right')  # Move legend to bottom right
plt.title('MNIST CNN Non-IID')

# Plot Global Validation Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs_range, global_acc_nonIID, color='red', linestyle="dashed", label="B = 10  E = 5")
plt.xlabel('Communication Rounds')
plt.ylabel('Test Accuracy')
plt.legend(loc='lower right')  # Move legend to bottom right
plt.title('MNIST CNN Non-IID')

plt.tight_layout()
plt.show()