# Initialization

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.utils.data as data_utils
import random
from copy import deepcopy

statistical_heterogeneity = 5

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)  # Change from 1 to 3 input channels
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)  # Adjust the size for CIFAR-10
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 5 * 5)  # Adjust the size for CIFAR-10
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def create_non_iid_partitions(dataset, num_clients):
    num_classes = 10
    class_indices = [[] for _ in range(num_classes)]
    
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    
    # Ensure randomness in class selection for each client
    all_classes = list(range(num_classes))
    
    client_local_datasets = []
    for i in range(num_clients):
        # Randomly select two major classes for each client
        major_classes = random.sample(all_classes, 2)

        # Allocate all data from the two major classes
        client_indices = class_indices[major_classes[0]] + class_indices[major_classes[1]]
        
        #Add a small number of samples from other classes
        minor_indices = []
        for cls in set(range(num_classes)) - set(major_classes):
            n_samples = len(class_indices[cls]) // num_clients // statistical_heterogeneity  # 50 times less than major classes
            minor_indices.extend(class_indices[cls][i * n_samples: (i + 1) * n_samples])

        client_indices = client_indices + minor_indices
        random.shuffle(client_indices)  # Shuffle to mix data from different classes
        client_local_datasets.append(data_utils.Subset(dataset, client_indices))

    return client_local_datasets

def create_iid_partitions(dataset, num_clients):

    client_datasets = []
    for i in range(num_clients):
        client_dataset = data_utils.Subset(dataset, list(range(i * len(dataset) // num_clients, (i + 1) * len(dataset) // num_clients)))
        client_datasets.append(client_dataset)
    return client_datasets

# Training set up using CIFAR-10

In [11]:
myGPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(myGPU)
# load the data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for RGB channels
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_function = nn.CrossEntropyLoss()
global_epochs = 20
local_epochs = 10
number_of_clients = 5
client_datasets=create_iid_partitions(trainset, number_of_clients)

Files already downloaded and verified
Files already downloaded and verified


# Federated Learning

In [12]:
def server_aggregate(model, state_dict_list):
    # average the model
    aggregated_state = {}
    for key,parameter in model.named_parameters():
        tensor_to_aggregate = []
        for client_state_dict in state_dict_list:
            client_tensor = client_state_dict[key].float()
            tensor_to_aggregate.append(client_tensor)
        
        stacked_tensor = torch.stack(tensor_to_aggregate,dim=0)
        mean_tensor = torch.mean(stacked_tensor,dim=0)
        aggregated_state[key] = mean_tensor
    model.load_state_dict(aggregated_state)

    return model 

def difference_models_norm_2(local_model, initial_model):
    tensor_1 = list(local_model.parameters())
    tensor_2 = list(initial_model.parameters())
    sub_norm = []
    for i in range(len(tensor_1)):
        s = torch.norm(tensor_1[i].to(myGPU) - tensor_2[i].to(myGPU),p=2)
        sub_norm.append(s)
    return sum(sub_norm)

def client_update(received_model, train_data, local_optimizer, loss_f, epoch,client_id,mu,algorithm,sys_heter):
    local_model = received_model.to(myGPU)
    initial_model = received_model.to(myGPU)

    random_chance = random.randint(0, 10) # Randomly decide if the client is weak or strong
    if random_chance >= sys_heter or algorithm == "FedProx":
    
        print(f"Client {client_id+1} starts training...")
        local_model.train()
        
        for i in range(epoch):
            running_loss = 0.0

            for feature, label in train_data:
                local_optimizer.zero_grad()
                feature, label = feature.to(myGPU), label.to(myGPU)
                outputs = local_model(feature)
                local_loss = loss_f(outputs, label)
                loss_prox = (mu / 2) * difference_models_norm_2(local_model, initial_model) # perform model updates penalization using proximal term
                loss = local_loss + loss_prox
                loss.backward()
                local_optimizer.step()
                running_loss += loss.item()
      
            print(f"Epoch {i+1} loss: {running_loss / len(train_data)}")
        print("\n")
        return local_model.state_dict()
        # return model
    
    else: # If the client is weak, it will not train
        print(f"Client {client_id+1} dropped")
        return local_model.state_dict()

def evaluate(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for test_feature, test_labels in testloader:
         
            test_feature, test_labels = test_feature.to(myGPU), test_labels.to(myGPU)
            outputs = model(test_feature)
            test_loss = loss_function(outputs, test_labels)
            _, predicted = torch.max(outputs.data, 1)
            total += test_labels.size(0)
            correct += (predicted == test_labels).sum().item()
            test_loss += test_loss.item()
    accuracy = 100 * correct / total
    loss = test_loss / len(testloader)
    loss = loss.item()

    # print(f"Accuracy of the network: {accuracy}, Loss: {loss}")
    return [accuracy, loss]

def federated_learning(model, mu,client_datasets, testloader, optimizer, loss_function, global_epochs, local_epochs,algorithm,sys_heter):
    
    for i, client_dataset in enumerate(client_datasets):
        client_labels = [trainset.targets[idx] for idx in client_dataset.indices]  # Access the labels for each subset
        unique_labels = np.unique(client_labels)
        # print(f"Client {i} has {len(unique_labels)} unique labels: {unique_labels}, and {len(client_dataset)} samples")
        
    # initial_phase = evaluate(model, testloader)     
    
    # Create a dataloader for each client
    client_dataloaders = [data_utils.DataLoader(dataset, batch_size=256, shuffle=True, num_workers=2) for dataset in client_datasets]
    global_history = []
    global_model = model.to(myGPU)
    for global_epoch in range(global_epochs):
        
        print(f"Global Round {global_epoch+1}:") 
        # print(f"'{'Global Round':=^100}'")
        state_dicts = []
        for client_id, client_dataloader in enumerate(client_dataloaders):
            client_state_dict = client_update(model, client_dataloader, optimizer, loss_function, local_epochs,client_id,mu,algorithm,sys_heter)
            state_dicts.append(client_state_dict)
        
        global_model = server_aggregate(global_model, state_dicts)
        model.load_state_dict(global_model.state_dict())
        
    return global_model

# Reinforcement Learning Environment setup.

In [13]:
import gym
from gym import spaces

class FedProxTuningEnv(gym.Env):
    def __init__(self, fl_model,train_datasets, test_loader, optimizer, loss_function, global_epochs=1, local_epochs=2, algorithm='FedProx', sys_heter=5):
        super(FedProxTuningEnv, self).__init__()

        self.fl_model = fl_model
        self.train_datasets = train_datasets
        self.test_loader = test_loader
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.global_epochs = global_epochs
        self.local_epochs = local_epochs
        self.algorithm = algorithm
        self.sys_heter = sys_heter

        self.current_mu = 0.1  # Initial mu value
        self.current_round = 0

        # Define action and observation space
        self.action_space = spaces.Discrete(2)  # Actions: 0 for decreasing mu, 1 for increasing mu
        self.observation_space = spaces.Box(low=np.array([0.0, 0]), high=np.array([10.0, global_epochs]), dtype=np.float32)

    def reset(self):
        self.current_round = 0
        # Use the federated learning model for evaluation
        print(f"\n'{'Restarting Environment':=^100}'\n")
        initial_metrics = evaluate(self.fl_model, self.test_loader)  # Make sure this is the CNN model
        print("Current accuracy: ", initial_metrics[0], "Current loss: ", initial_metrics[1])
        print("Current mu: ", self.current_mu)
        initial_loss = initial_metrics[1].item() if torch.is_tensor(initial_metrics[1]) else initial_metrics[1]
        return np.array([self.current_mu, initial_loss], dtype=np.float32)

    def step(self, action):
        # Adjust mu based on action
        print(f'Action take: {action}')
       
        if action == 0 and self.current_mu > 0.01:  # Ensure mu stays positive
            print(f'Decreasing mu to {self.current_mu - 0.01}')
            self.current_mu -= 0.01
        elif action == 1 and self.current_mu < 10.0:  # Upper bound for mu
            print(f'Increasing mu to {self.current_mu + 0.01}')
            self.current_mu += 0.01
        print(f"\n'{'Stepping':=^100}'\n")
        previous_metrics = evaluate(self.fl_model, self.test_loader)
        # Run one global epoch of federated learning with the current mu
        # print(f'Running global epoch {self.current_round + 1} with mu = {self.current_mu}')
        print(f"\n'{'Start Federated Learning':-^100}'\n")
        global_model = federated_learning(self.fl_model, self.current_mu, self.train_datasets, self.test_loader, self.optimizer, self.loss_function, self.global_epochs, self.local_epochs, self.algorithm, self.sys_heter)
        print(f"\n'{'End Federated Learning':-^100}'\n")
        current_metrics = evaluate(global_model, self.test_loader)
        
        # Update the state
        self.current_round += 1
        state = np.array([self.current_mu, current_metrics[1]])  # Using loss as part of the state
        # Reward is based on improvement in accuracy
        reward = current_metrics[0] - previous_metrics[0]  # Change in accuracy
        print(f'Previous Accuracy: {previous_metrics[0]}')
        print(f'Current Accuracy: {current_metrics[0]}')
        print(f'New state: {state}')
        print(f'Reward: {reward}')
        done = self.current_round >= self.global_epochs

        if reward > 0:
            action = 0
        else :
            action = 1

        return state, reward, done, {}

    def render(self, mode='console'):
        if mode == 'console':
            print(f'Round: {self.current_round}, Mu: {self.current_mu}')


In [14]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

fl_model = CNN().to(myGPU)
# old_global_model = None
optimizer = optim.SGD(fl_model.parameters(), lr=0.01, momentum=0.9)
# Instantiate the environment
env = FedProxTuningEnv(fl_model, client_datasets, testloader, optimizer, loss_function)

# Instantiate the agent
modelRL = PPO("MlpPolicy", env, verbose=1)

# Train the agent
modelRL.learn(total_timesteps=1)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




Current accuracy:  10.6 Current loss:  0.11586066335439682
Current mu:  0.1
Action take: 1
Increasing mu to 0.11



'--------------------------------------Start Federated Learning--------------------------------------'

Global Round 1:
Client 1 starts training...
Epoch 1 loss: 2.3020146250724793
Epoch 2 loss: 2.286200922727585


Client 2 starts training...
Epoch 1 loss: 2.1921368509531023
Epoch 2 loss: 2.012925484776497


Client 3 starts training...
Epoch 1 loss: 1.8834530115127563
Epoch 2 loss: 1.7849530011415482


Client 4 starts training...
Epoch 1 loss: 1.7410993814468383
Epoch 2 loss: 1.6606433540582657


Client 5 starts training...
Epoch 1 loss: 1.617758470773697
Epoch 2 loss: 1.569148251414299



'---------------------------------------End Federated Learning---------------------------------------'

Previous Accuracy: 10.6
Current Accuracy: 43.3
New state: [0.11       0.07040244]
Reward: 32.699999999999996


Current accuracy:  43.3 Current loss:  0.07040244340896606
Current mu:  

KeyboardInterrupt: 