<h1>Federated Learning: Scheduler Comparison</h1>

<t>This notebook compares *Random*, *Age-based*, and *Age-Of-Update OR DataShapley* (AoU), and Version Age-based Schedulers using the MNIST and CIFAR-10 Datasets.

> Two models are considered,
> 1. A Multi-layer Perceptron comprising of two hidden layers with 64 units each, utilizing the ReLU activation function
> 2. A Convolutional Neural Network encompassing two convolutional layers with max pooling, two fully connected layers, and a softmax output layer

> Two types of data distributions were considered when using the MNIST Dataset,
>1. IID: Data is shuffled and then divided up across 100 clients each receiving 600 examples
>2. Non-IID: Data is sorted by digit label, divided up into 200 'shards' of 300 examples, and then each client receieves 2 'shards

</t>


In [None]:
# Import Global Dependencies
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch
import torch.nn.functional as F

# Import Helper Libaries
import matplotlib.pyplot as plt
from torchinfo import summary
import numpy as np
import random
import time

filepath = "C:/Users/aidan_000/Desktop/UNCC/Github/Fed-Learning/data" 

<h2>Multi-layer Perceptron Model Architecture</h2>

<t>*"The initial experiment involves training an MLP with the MNIST dataset. This MLP comprises two
hidden layers with 64 units each, utilizing the ReLU activation function"*</t>

In [None]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)  # MNIST images are 28x28
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # 10 classes for MNIST

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

<h2>CIFAR-10 CNN Model Architecture</h2>

<t>*The second experiment focuses on training a Convolutional Neural Network CNN on the CIFAR-10 dataset. The CNN
architecture encompasses two convolutional layers with max pooling, two fully connected layers, and a softmax output
layer*</t>

In [None]:
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(64*8*8, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return self.softmax(x)

<h2>Federated Learning Algorithms</h2>

<h3>Global Aggregator</h3>

In [None]:
def global_aggregate(global_model, client_models):
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k] for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

<h3>Global Model Evaluation</h3>

In [None]:
def model_evaluation(global_model, validation_loader):
    global_model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in validation_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = global_model(inputs)
            loss += nn.CrossEntropyLoss()(output, labels).item()  # Using the criterion
            _, predicted = torch.max(output.data, 1)
            correct += (predicted == labels).sum().item()

    loss /= len(validation_loader.dataset)
    accuracy = correct / len(validation_loader.dataset)

    return loss, accuracy

<h3>Client Update</h3>

In [None]:
def client_update(client, optimizer, training_loader, epochs):
    client.train()
    for epoch in range(epochs):
        for batch_idx, (inputs, labels) in enumerate(training_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            output = client(inputs)
            loss = F.cross_entropy(output, labels)
            loss.backward()
            optimizer.step()
    return loss.item()

<h2>Training with different schedulers</h2>

<h3>Training using Random Scheduling</h3>

In [None]:
def RNG_training(model_type, global_model, lr, total_clients, clients_per_round, total_rounds, local_epochs, training_loaders, validation_loader):
    clients = [model_type().to(device) for _ in range(total_clients)]

    for models in clients:
        models.load_state_dict(global_model.state_dict())
    
    opt = [optim.SGD(models.parameters(), lr=lr) for models in clients]

    average_losses, valid_losses, valid_accuracies = [], [], []

    global_start_time = time.time()
    for round in range(total_rounds):
        start_time = time.time()
        clients_idx = np.random.permutation(total_clients)[:clients_per_round]

        client_losses = 0
        selected_models = []
        
        for i in range(clients_per_round):
            clients[clients_idx[i]].load_state_dict(global_model.state_dict())
            client_losses += client_update(clients[clients_idx[i]], opt[clients_idx[i]], training_loaders[clients_idx[i]], local_epochs)
            selected_models.append(clients[clients_idx[i]])
        
        global_aggregate(global_model, selected_models)
    
        avg_loss = client_losses / clients_per_round
        valid_loss , valid_accuracy = model_evaluation(global_model, validation_loader)

        average_losses.append(avg_loss)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)
    
        end_time = time.time()
        round_time = end_time - start_time
    
        if (round % 10) == 0:
            rounds_end_time = time.time()
            rounds_time = rounds_end_time - global_start_time
            print('Round {:3d}, Time (secs) {:.2f}: Average loss {:.4f}, Validation Loss {:.4f}, Validation Accuracy {:.4f}'.format(round + 1, rounds_time, avg_loss, valid_loss, valid_accuracy))
    return average_losses, valid_losses, valid_accuracies

<h3>Training using Age-based Scheduling</h3>

In [None]:
def ABS_training(model_type, global_model, lr, total_clients, clients_per_round, total_rounds, local_epochs, training_loaders, validation_loader):
    clients = [model_type().to(device) for _ in range(total_clients)]
    
    for models in clients:
        models.load_state_dict(global_model.state_dict())
    
    opt = [optim.SGD(models.parameters(), lr=lr) for models in clients]
    clients_age = np.zeros(total_clients)
    
    average_losses, valid_losses, valid_accuracies = [], [], []

    global_start_time = time.time()
    for round in range(total_rounds):
        start_time = time.time()

        # Age-based scheduling: select the clients with the highest age
        old_clients_idx = np.argsort(clients_age)[-clients_per_round:]

        clients_age += 1
        clients_age[old_clients_idx] = 0  # Reset the age of the selected clients

        client_losses = 0      
        selected_models = []
        
        for i in range(clients_per_round):
            clients[old_clients_idx[i]].load_state_dict(global_model.state_dict())
            client_losses += client_update(clients[old_clients_idx[i]], opt[old_clients_idx[i]], training_loaders[old_clients_idx[i]], local_epochs)
            selected_models.append(clients[old_clients_idx[i]])
        

        global_aggregate(global_model, selected_models)
    
        avg_loss = client_losses / clients_per_round
        valid_loss , valid_accuracy = model_evaluation(global_model, validation_loader)

        average_losses.append(avg_loss)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)
    
        end_time = time.time()
        round_time = end_time - start_time
    
        if (round % 10) == 0:
            rounds_end_time = time.time()
            rounds_time = rounds_end_time - global_start_time
            print('Round {:3d}, Time (secs) {:.2f}: Average loss {:.4f}, Validation Loss {:.4f}, Validation Accuracy {:.4f}'.format(round + 1, rounds_time, avg_loss, valid_loss, valid_accuracy))
    return average_losses, valid_losses, valid_accuracies

<h3>Training using Age of Update (AoU) OR Data Shapley value Scheduling</h3>

<t>If a UE’s k AoU surpasses the threshold or its Data Shapley value exceeds the Shapley value of the current highest value in the list, or both conditions are met,
the UE k is positioned at the beginning of the list; otherwise, it is placed at the end</t>

In [None]:
def create_client_map(client_models, training_loaders):
    client_map = {}
    for idx, client_model in enumerate(client_models):
        client_id = f"client_{idx + 1}"
        optimizer = optim.SGD(client_model.parameters(), lr=lr)
        client_map[client_id] = {
            'model': client_model,
            'optimizer': optimizer,
            'training_loader': training_loaders[idx],
            'accuracies': [],
            'losses': [],
            'shapley_value': random.random(),
            'age': 0
        }
    return client_map

In [None]:
def update_performance(client_map, selected_clients, shapley_threshold):
    for client_id in client_map.keys():
        if client_id in selected_clients:
            client_accuracies = client_map[client_id]["accuracies"]
            mean_acc = sum(client_accuracies) / len(client_accuracies)
            variance = sum(((x - mean_acc) ** 2) for x in client_accuracies) / len(client_accuracies)
            if variance > shapley_threshold:
                updated_value = client_map[client_id]["shapley_value"] + 1.0    
            else:
                updated_value = client_map[client_id]["shapley_value"]
        else:
            updated_value = client_map[client_id]["shapley_value"] 

        mean_performance_score = (sum(client_map[client_id]["losses"]) + updated_value) / (len(client_map[client_id]["losses"]) + 1)
        client_map[client_id]["shapley_value"] = mean_performance_score

In [None]:
def AoU_Scheduler(client_map, AoU_threshold, clients_per_round):
    selected_clients = []
    client_position = []
    highest_score = max(client_map.values(), key=lambda x: x['shapley_value'])['shapley_value']
    
    for client_id, client_info in client_map.items():
        client_score = client_info['shapley_value']
        client_age = client_info['age']
        
        if client_score >= highest_score or client_age > AoU_threshold:
            client_position.insert(0, client_id)  # Prepend the client ID
        else:
            client_position.append(client_id)

    for i in range(clients_per_round):
        selected_clients.append(client_position[i])

    return selected_clients

In [None]:
def train_and_test_client(client_map, selected_clients, local_epochs, validation_loader, global_model, performance_threshold):
    for client_id in selected_clients:
        client_map[client_id]['model'].load_state_dict(global_model.state_dict())
        model = client_map[client_id]['model']
        model.train()

        optimizer = client_map[client_id]['optimizer']
        criterion = nn.CrossEntropyLoss()

        for epoch in range(local_epochs):
            for inputs, labels in client_map[client_id]['training_loader']:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
        
        model.eval()
        correct = 0
        for inputs, labels in validation_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
        accuracy = correct / (len(validation_loader.dataset))

        client_map[client_id]['accuracies'].append(accuracy)
        client_map[client_id]['losses'].append(loss.item())
    
    update_performance(client_map, selected_clients, performance_threshold)

In [None]:
def AoU_OR_DataShapley_training(model_type, global_model, total_clients, clients_per_round, total_rounds, local_epochs, training_loaders, validation_loader, age_threshold, shapley_threshold):
    client_models = [model_type().to(device) for _ in range(total_clients)]
    for models in client_models:
        models.load_state_dict(global_model.state_dict())

    client_map = create_client_map(client_models, training_loaders)

    average_losses, valid_losses, valid_accuracies = [], [], []

    global_start_time = time.time()
    for round in range(total_rounds):
        start_time = time.time()

        selected_clients = AoU_Scheduler(client_map, age_threshold, clients_per_round)

        for client_id in client_map:
            client_map[client_id]['age'] = client_map[client_id]['age'] ** 2 + 1

        # Reset the age of selected clients to 0
        for client_id in selected_clients:
            client_map[client_id]['age'] = 0
        
        train_and_test_client(client_map, selected_clients, local_epochs, validation_loader, global_model, shapley_threshold)        
        update_performance(client_map, selected_clients, shapley_threshold)
        global_aggregate(global_model, [client_map[client_id]['model'] for client_id in selected_clients])

        client_losses = sum(client_map[client_id]['losses'][-1] for client_id in selected_clients)
    
        avg_loss = client_losses / len(selected_clients)
        valid_loss , valid_accuracy = model_evaluation(global_model, validation_loader)

        average_losses.append(avg_loss)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)
    
        end_time = time.time()
        round_time = end_time - start_time
    
        if (round % 10) == 0:
            rounds_end_time = time.time()
            rounds_time = rounds_end_time - global_start_time
            print('Round {:3d}, Time (secs) {:.2f}: Average loss {:.4f}, Validation Loss {:.4f}, Validation Accuracy {:.4f}'.format(round + 1, rounds_time, avg_loss, valid_loss, valid_accuracy))

    return average_losses, valid_losses, valid_accuracies


<h3>Training using Version Age-based (VAoI)</h3>

<t></t>

In [None]:
def VAoI_client_map(client_models, training_loaders):
    client_map = {}
    for idx, client_model in enumerate(client_models):
        client_id = f"client_{idx + 1}"
        optimizer = optim.SGD(client_model.parameters(), lr=lr)
        client_map[client_id] = {
            'model': client_model,
            'optimizer': optimizer,
            'training_loader': training_loaders[idx],
            'accuracies': [],
            'losses': [],
            'version_age': 0
        }
    return client_map

In [None]:
def manhattan_norm(client_model, global_model):
    client_weights = client_model.state_dict()
    global_weights = global_model.state_dict()
    for key in client_weights.keys():
        norm = torch.norm(client_weights[key] - global_weights[key], p=1)
        norm = norm.item()
    return norm

In [None]:
def VAoI_Scheduler(client_map, global_model, tau, h=lambda x: np.exp(x)):
    num_clients = len(client_map)
    client_version_ages = []

    # Calculate the version age for each client
    for client_id, client_info in client_map.items():
        model = client_info['model']
        distance = manhattan_norm(model, global_model)
        if distance >= tau:
            client_info['version_age'] += 1
        client_version_ages.append(client_info['version_age'])

    # Calculate the selection probabilities based on version age
    selection_probs = [h(age) for age in client_version_ages]
    selection_probs = [prob / sum(selection_probs) for prob in selection_probs]

    # Select the clients based on the probabilities
    selected_clients = np.random.choice(list(client_map.keys()), size=int(0.1 * num_clients), p=selection_probs, replace=False)

    # Reset the version age for the selected clients
    for client_id in selected_clients:
        client_map[client_id]['version_age'] = 0

    return selected_clients

In [None]:
def VAoI_training(model_type, global_model, lr, total_clients, clients_per_round, total_rounds, local_epochs, training_loaders, validation_loader, tau):
    client_models = [model_type().to(device) for _ in range(total_clients)]
    for models in client_models:
        models.load_state_dict(global_model.state_dict())

    client_map = VAoI_client_map(client_models, training_loaders)
    
    opt = [optim.SGD(models.parameters(), lr=lr) for models in client_models]
    
    average_losses, valid_losses, valid_accuracies = [], [], []

    global_start_time = time.time()
    for round in range(total_rounds):
        start_time = time.time()

        # Use the VAoI scheduler to select the clients
        selected_clients = VAoI_Scheduler(client_map, global_model, tau, h=lambda x: np.exp(x))

        client_losses = 0
        selected_models = []
        for client_id in selected_clients:
            client_info = client_map[client_id]
            client_model = client_info['model']
            client_optimizer = client_info['optimizer']
            client_training_loader = client_info['training_loader']

            # Load the global model weights to the selected client
            client_model.load_state_dict(global_model.state_dict())

            # Perform local training on the selected client
            client_loss = client_update(client_model, client_optimizer, client_training_loader, local_epochs)
            client_losses += client_loss
            selected_models.append(client_model)

        # Aggregate the selected client models to update the global model
        global_aggregate(global_model, selected_models)

        avg_loss = client_losses / len(selected_clients)
        valid_loss, valid_accuracy = model_evaluation(global_model, validation_loader)

        average_losses.append(avg_loss)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)

        end_time = time.time()
        round_time = end_time - start_time

        if (round % 10) == 0:
            rounds_end_time = time.time()
            rounds_time = rounds_end_time - global_start_time
            print('Round {:3d}, Time (secs) {:.2f}: Average loss {:.4f}, Validation Loss {:.4f}, Validation Accuracy {:.4f}'.format(round + 1, rounds_time, avg_loss, valid_loss, valid_accuracy))

    return average_losses, valid_losses, valid_accuracies

<h2>Hyperparameters for Training Experience</h2>

In [None]:
# Device configuration: use CUDA if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Learning configuration
lr = 0.015
total_rounds = 100  # Total number of training rounds

# Client configuration
total_clients = 100  # Total number of clients
clients_per_round = 10  # Number of clients selected per round

# Local training configuration
local_batchsize = 50  # Batch size for local training
local_epochs = 5  # Number of epochs for local training

shapley_threshold = 0.8
age_threshold = 10

tau = 0.1

<h2>IID Data Preparation for MNIST and CIFAR-10 Dataset</h2>
<t>*The IID data is shuffled and then divided up across 100 clients each receiving 600 examples.*  

*Exclusively use independent and identically distributed i.i.d. distributions for CIFAR-10 due to the absence of a natural
data user partition*</t>

In [None]:
MNISTtransform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# MNIST IID Dataset
MNIST_training_dataset = datasets.MNIST(filepath, train=True, download=True, transform=MNISTtransform)
MNIST_training_datasplit = torch.utils.data.random_split(MNIST_training_dataset, [int(MNIST_training_dataset.data.shape[0] / total_clients) for _ in range(total_clients)])
MNIST_iid_training = [torch.utils.data.DataLoader(x, batch_size=local_batchsize, shuffle=True) for x in MNIST_training_datasplit]

MNIST_validation_dataset = datasets.MNIST(filepath, train=False, download = True, transform=MNISTtransform)
MNIST_iid_validation = torch.utils.data.DataLoader(MNIST_validation_dataset, batch_size=local_batchsize, shuffle=True)


# Load CIFAR-10 dataset
CIFARtransform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 
])

dataset = datasets.CIFAR10(filepath, train=True, download=True, transform=CIFARtransform)

CIFAR10_dataset = torch.utils.data.random_split(dataset, [len(dataset) // total_clients for _ in range(total_clients)])
CIFAR10_training = [torch.utils.data.DataLoader(x, batch_size=local_batchsize, shuffle=True) for x in CIFAR10_dataset]

CIFAR10_validation = torch.utils.data.DataLoader(datasets.CIFAR10(filepath, train=False, transform=CIFARtransform), batch_size=local_batchsize, shuffle=True)

<h3>MLP Model Training with MNIST IID</h3>

Training is done using *'Random'* Scheduling, *'Age-based'* Scheduling, and *'Age of Update OR Data Shapley'* Scheduling (See RNG_training, ABS_training, AoU_OR_DataShapley)

In [180]:
# MNIST_iid_RNG = MLP().to(device)
# MNIST_iid_ABS = MLP().to(device)
# MNIST_iid_AoU_OR_DataShapley = MLP().to(device)
MNIST_iid_VAoI = MLP().to(device)

# print("=== Training: Model - MLP, Schedule - Random, Data Distribution - MNIST IID  ===")
# iid_MNIST_RNG_avg_losses, iid_MNIST_RNG_eval_losses, iid_MNIST_RNG_accuracies = RNG_training(MLP, MNIST_iid_RNG, lr, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_iid_training, MNIST_iid_validation)

# print("\n=== Training: Model - MLP, Schedule - Age-Based, Data Distribution - MNIST IID ===")
# iid_MNIST_ABS_avg_losses, iid_MNIST_ABS_eval_losses, iid_MNIST_ABS_accuracies = ABS_training(MLP, MNIST_iid_ABS, lr, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_iid_training, MNIST_iid_validation)

# print("\n=== Training: Model - MLP, Schedule - Age of Update OR Data Shapley, Data Distribution - MNIST IID ===")
# iid_MNIST_AoU_OR_DataShapley_avg_losses, iid_MNIST_AoU_OR_DataShapley_eval_losses, iid_MNIST_AoU_OR_DataShapley_accuracies = AoU_OR_DataShapley_training(MLP, MNIST_iid_AoU_OR_DataShapley, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_iid_training, MNIST_iid_validation, age_threshold, shapley_threshold)

print("\n=== Training: Model - MLP, Schedule - Version Age of Information, Data Distribution - MNIST IID ===")
iid_MNIST_VAoI_avg_losses, iid_MNIST_VAoI_eval_losses, iid_MNIST_VAoI_accuracies = VAoI_training(MLP, MNIST_iid_VAoI, lr, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_iid_training, MNIST_iid_validation, tau)

# # Save Final Models
# torch.save(MNIST_iid_RNG.state_dict(), '.\\Models\\MNIST_iid_RNG.pth')
# torch.save(MNIST_iid_ABS.state_dict(), '.\\Models\\MNIST_iid_ABS.pth')
# torch.save(MNIST_iid_AoU_OR_DataShapley.state_dict(), '.\\Models\\MNIST_iid_AoU_OR_DataShapley.pth')

Round   1, Time (secs) 6.79: Average loss 2.0328, Validation Loss 0.0410, Validation Accuracy 0.4623


KeyboardInterrupt: 

<h4>MLP Model Training/Inferencing Experience Comparison for MNIST IID</h4>

In [None]:
print(f'=================================== Final MLP Model Accuracies per Schedule ====================================')
print(f'Random Scheduled MNIST IID Model Accuracy: {iid_MNIST_RNG_accuracies[-1]},\nAge-based Scheduled MNIST IID  Model Accuracy: {iid_MNIST_ABS_accuracies[-1]},\nAoU OR DataShapley Scheduled MNIST IID Model Accuracy: {iid_MNIST_AoU_OR_DataShapley_accuracies[-1]}')
print(f'================================================================================================================')

epochs_range = range(1, total_rounds + 1)

# Plot Global Training Loss
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, iid_MNIST_RNG_eval_losses, color='red', linestyle="dashed", label="Random Schedule")
plt.plot(epochs_range, iid_MNIST_ABS_eval_losses, color='blue', linestyle="dotted", label="Age-based Schedule")
plt.plot(epochs_range, iid_MNIST_AoU_OR_DataShapley_eval_losses, color='green', linestyle="solid", label="AoU OR DataShapley Schedule")
plt.xlabel('Communication Rounds')
plt.ylabel('Train Loss')
plt.legend(loc='upper right')  
plt.title('MLP MNIST IID Loss Curve')

# Plot Global Validation Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs_range, iid_MNIST_RNG_accuracies, color='red', linestyle="dashed", label="Random Schedule")
plt.plot(epochs_range, iid_MNIST_ABS_accuracies, color='blue', linestyle="dotted", label="Age-based Schedule")
plt.plot(epochs_range, iid_MNIST_AoU_OR_DataShapley_accuracies, color='green', linestyle="solid", label="AoU OR DataShapley Schedule")
plt.xlabel('Communication Rounds')
plt.ylabel('Test Accuracy')
plt.legend(loc='lower right') 
plt.title('MLP MNIST IID Accuracy Curve')

plt.tight_layout()
plt.savefig('.\\Plots\\mnist_iid_results.png')
plt.show()

<h3>CNN Model Training with CIFAR-10 IID</h3>

Training is done using *'Random'* Scheduling, *'Age-based'* Scheduling, and *'Age of Update OR Data Shapley'* Scheduling (See RNG_training, ABS_training, AoU_OR_DataShapley)

In [None]:
# CIFAR10_iid_RNG = CNN().to(device)
# CIFAR10_iid_ABS = CNN().to(device)
# CIFAR10_iid_AoU_OR_DataShapley = CNN().to(device)

# print("=== Training: Model - CNN, Schedule - Random, Data Distribution - IID CIFAR-10 ===")
# iid_CIFAR10_RNG_avg_losses, iid_CIFAR10_RNG_eval_losses, iid_CIFAR10_RNG_accuracies = RNG_training(CNN, CIFAR10_iid_RNG, lr, total_clients, clients_per_round, total_rounds, local_epochs, CIFAR10_training, CIFAR10_validation)

# print("\n=== Training: Model - CNN, Schedule - Age-Based, Data Distribution - IID CIFAR-10 ===")
# iid_CIFAR10_ABS_avg_losses, iid_CIFAR10_ABS_eval_losses, iid_CIFAR10_ABS_accuracies = ABS_training(CNN, CIFAR10_iid_ABS, lr, total_clients, clients_per_round, total_rounds, local_epochs, CIFAR10_training, CIFAR10_validation)

# print("\n=== Training: Model - CNN, Schedule - Age of Update OR Data Shapley, Data Distribution - IID CIFAR-10 ===")
# iid_CIFAR10_AoU_OR_DataShapley_avg_losses, iid_CIFAR10_AoU_OR_DataShapley_eval_losses, iid_CIFAR10_AoU_OR_DataShapley_accuracies = AoU_OR_DataShapley_training(CNN, CIFAR10_iid_AoU_OR_DataShapley, lr, total_clients, clients_per_round, total_rounds, local_epochs, CIFAR10_training, CIFAR10_validation, age_threshold, shapley_threshold)

# # Save Final Models
# torch.save(CIFAR10_iid_RNG.state_dict(), '.\\Models\\CIFAR10_iid_RNG.pth')
# torch.save(CIFAR10_iid_ABS.state_dict(), '.\\Models\\CIFAR10_iid_ABS.pth')
# torch.save(CIFAR10_iid_AoU_OR_DataShapley.state_dict(), '.\\Models\\CIFAR10_iid_AoU_OR_DataShapley.pth')

<h4>CNN Model Training/Inferencing Experience Comparison for CIFAR-10 IID</h4>

In [None]:
print(f'=================================== Final CNN Model Accuracies per Schedule ====================================')
print(f'Random Scheduled CIFAR-10 IID Model Accuracy: {iid_CIFAR10_RNG_accuracies[-1]},\nAge-based Scheduled CIFAR-10 IID  Model Accuracy: {iid_CIFAR10_ABS_accuracies[-1]},\nAoU OR DataShapley Scheduled CIFAR-10 IID Model Accuracy: {iid_CIFAR10_AoU_OR_DataShapley_accuracies[-1]}')
print(f'================================================================================================================')

epochs_range = range(1, total_rounds + 1)

# Plot Global Training Loss
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, iid_CIFAR10_RNG_eval_losses, color='red', linestyle="dashed", label="Random Schedule")
plt.plot(epochs_range, iid_CIFAR10_ABS_eval_losses, color='blue', linestyle="dotted", label="Age-based Schedule")
plt.plot(epochs_range, iid_CIFAR10_AoU_OR_DataShapley_eval_losses, color='green', linestyle="solid", label="AoU OR DataShapley Schedule")
plt.xlabel('Communication Rounds')
plt.ylabel('Train Loss')
plt.legend(loc='upper right')  
plt.title('CNN CIFAR-10 IID Loss Curve')

# Plot Global Validation Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs_range, iid_CIFAR10_RNG_accuracies, color='red', linestyle="dashed", label="Random Schedule")
plt.plot(epochs_range, iid_CIFAR10_ABS_accuracies, color='blue', linestyle="dotted", label="Age-based Schedule")
plt.plot(epochs_range, iid_CIFAR10_AoU_OR_DataShapley_accuracies, color='green', linestyle="solid", label="AoU OR DataShapley Schedule")
plt.xlabel('Communication Rounds')
plt.ylabel('Test Accuracy')
plt.legend(loc='lower right') 
plt.title('CNN CIFAR-10 IID Accuracy Curve')

plt.tight_layout()
plt.savefig('.\\Plots\\cifar10_iid_results.png')
plt.show()

<h2>Non-IID Data Preparation for MNIST Dataset</h2>

<t>*The Non-IID data is sorted by digit label, divided up into 200 'shards' of 300 examples, and then each client receieves 2 'shards'*</t>

In [None]:
# MNIST Non-IID Dataset
training_labels = torch.stack([MNIST_training_dataset.targets == i for i in range(10)])
training_labels_split = []

for i in range(5):
    training_labels_split += torch.split(torch.where(training_labels[(2 * i):(2 * (i + 1))].sum(0))[0], int(len(MNIST_training_dataset.data) / total_clients))
training_dataset_split = [torch.utils.data.Subset(MNIST_training_dataset, labels) for labels in training_labels_split]
MNIST_noniid_training = [torch.utils.data.DataLoader(x, batch_size=local_batchsize, shuffle=True) for x in training_dataset_split]

MNIST_noniid_validation = torch.utils.data.DataLoader(MNIST_validation_dataset, batch_size=local_batchsize, shuffle=True)

<h3>MLP Model Training with MNIST Non-IID</h3>

Training is done using *'Random'* Scheduling, *'Age-based'* Scheduling, and *'Age of Update OR Data Shapley'* Scheduling (See RNG_training, ABS_training, AoU_OR_DataShapley)

In [None]:
# MNIST_noniid_RNG = MLP().to(device)
# MNIST_noniid_ABS = MLP().to(device)
# MNIST_noniid_AoU_OR_DataShapley = MLP().to(device)

# print("=== Training: Model - MLP, Schedule - Random, Data Distribution - MNIST Non-IID  ===")
# noniid_MNIST_RNG_avg_losses, noniid_MNIST_RNG_eval_losses, noniid_MNIST_RNG_accuracies = RNG_training(MLP, MNIST_noniid_RNG, lr, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_noniid_training, MNIST_noniid_validation)

# print("\n=== Training: Model - MLP, Schedule - Age-Based, Data Distribution - MNIST Non-IID  ===")
# noniid_MNIST_ABS_avg_losses, noniid_MNIST_ABS_eval_losses, noniid_MNIST_ABS_accuracies = ABS_training(MLP, MNIST_noniid_ABS, lr, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_noniid_training, MNIST_noniid_validation)

# print("\n=== Training: Model - MLP, Schedule - Age of Update OR Data Shapley, Data Distribution - MNIST Non-IID  ===")
# noniid_MNIST_AoU_OR_DataShapley_avg_losses, noniid_MNIST_AoU_OR_DataShapley_eval_losses, noniid_MNIST_AoU_OR_DataShapley_accuracies = AoU_OR_DataShapley_training(MLP, MNIST_noniid_AoU_OR_DataShapley, lr, total_clients, clients_per_round, total_rounds, local_epochs, MNIST_noniid_training, MNIST_noniid_validation, age_threshold, shapley_threshold)

# torch.save(MNIST_noniid_RNG.state_dict(), '.\\Models\\MNIST_noniid_RNG.pth')
# torch.save(MNIST_noniid_ABS.state_dict(), '.\\Models\\MNIST_noniid_ABS.pth')
# torch.save(MNIST_noniid_AoU_OR_DataShapley.state_dict(), '.\\Models\\MNIST_noniid_AoU_OR_DataShapley.pth')

<h4>MLP Model Training/Inferencing Experience Comparison for MNIST Non-IID</h4>

In [None]:
print(f'=================================== Final MLP Model Accuracies per Schedule ====================================')
print(f'Random Scheduled MNIST Non-IID Model Accuracy: {noniid_MNIST_RNG_accuracies[-1]},\nAge-based Scheduled MNIST Non-IID  Model Accuracy: {noniid_MNIST_ABS_accuracies[-1]},\nAoU OR DataShapley Scheduled MNIST Non-IID Model Accuracy: {noniid_MNIST_AoU_OR_DataShapley_accuracies[-1]}')
print(f'================================================================================================================')

epochs_range = range(1, total_rounds + 1)

# Plot Global Training Loss
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, noniid_MNIST_RNG_eval_losses, color='red', linestyle="dashed", label="Random Schedule")
plt.plot(epochs_range, noniid_MNIST_ABS_eval_losses, color='blue', linestyle="dotted", label="Age-based Schedule")
plt.plot(epochs_range, noniid_MNIST_AoU_OR_DataShapley_eval_losses, color='green', linestyle="solid", label="AoU OR DataShapley Schedule")
plt.xlabel('Communication Rounds')
plt.ylabel('Train Loss')
plt.legend(loc='upper right')  
plt.title('MLP MNIST Non-IID Loss Curve')

# Plot Global Validation Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs_range, noniid_MNIST_RNG_accuracies, color='red', linestyle="dashed", label="Random Schedule")
plt.plot(epochs_range, noniid_MNIST_ABS_accuracies, color='blue', linestyle="dotted", label="Age-based Schedule")
plt.plot(epochs_range, noniid_MNIST_AoU_OR_DataShapley_accuracies, color='green', linestyle="solid", label="AoU OR DataShapley Schedule")
plt.xlabel('Communication Rounds')
plt.ylabel('Test Accuracy')
plt.legend(loc='lower right') 
plt.title('MLP MNIST Non-IID Accuracy Curve')

plt.tight_layout()
plt.savefig('.\\Plots\\mnist_noniid_results.png')
plt.show()

<h2>Global Model Complexity Comparisons</h2>

In [None]:
def print_model_summary(model, model_name, dataset_name):
    if dataset_name == "MNIST":
        input_size = (1, 28, 28)
    elif dataset_name == "CIFAR-10":
        input_size = (3, 32, 32)

    model_stats = summary(model, input_size=input_size, col_width=16, col_names=["kernel_size", "output_size", "num_params", "mult_adds"], row_settings=["var_names"])
    print(f"\n{'='*30} {model_name} {'='*30}")
    print(model_stats)
    print('='*80)

print("\t\t\t\tMLP Model Comparisons using MNIST IID\n")
print_model_summary(MNIST_iid_RNG, "Random IID Model", "MNIST")
print_model_summary(MNIST_iid_ABS, "Age-Based IID Model", "MNIST")
print_model_summary(MNIST_iid_AoU_OR_DataShapley, "AoU OR DataShapley IID Model", "MNIST")
print()

print("\t\t\t\tMLP Model Comparisons using MNIST Non-IID\n")
print_model_summary(MNIST_noniid_RNG, "Random Non-IID Model", "MNIST")
print_model_summary(MNIST_noniid_ABS, "Age-Based Non-IID Model", "MNIST")
print_model_summary(MNIST_noniid_AoU_OR_DataShapley, "AoU OR DataShapley Non-IID Model", "MNIST")
print()

print("\t\t\t\tCNN Model Comparisons using CIFAR10 IID\n")
print_model_summary(CIFAR10_iid_RNG, "Random IID Model", "CIFAR-10")
print_model_summary(CIFAR10_iid_ABS, "Age-Based IID Model", "CIFAR-10")
print_model_summary(CIFAR10_iid_AoU_OR_DataShapley, "AoU OR DataShapley IID Model", "CIFAR-10")
