# Lee et al. Multilayer Perceptron
In this notebook, I provide an implementation of the multilayer perceptron (MLP) proposed by Lee et al. (Lee et al., 2022) in a federated environment. Note that the proposed MLP is altered for the purpose of binary sequence-to-label classification; corresponding comments detail all adjustments. The spectral data are loaded from the per-seed CSV files in the "../No_MSC" folder (one training and one validation file per client, plus a combined validation file per seed).

In [None]:
# Dependencies
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset
import flwr as fl
from collections import OrderedDict
from typing import List, Tuple, Dict
from sklearn.metrics import confusion_matrix
from typing import Optional
import os
import logging

# Silence Flower's own logger so that only CRITICAL records are emitted.
logging.getLogger("flwr").setLevel(level=logging.CRITICAL)

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Number of simulated federated clients.
NUM_CLIENTS = 5

In [None]:
# Load and preprocess data
def load_and_preprocess(train_file, val_file):
    """
    Load a training and a validation CSV file and standardize their samples.

    Both files are read with ``header=None``. From each one the first row and
    the first column are discarded, the last row becomes the label vector and
    the remaining rows become the sample matrix. The transposes below treat
    each column of that matrix as one sample.

    Parameters:
        train_file (str): Path to the training CSV file.
        val_file (str): Path to the validation CSV file.

    Returns:
        Tuple (train_samples, train_labels, val_samples, val_labels) of float
        NumPy arrays. The sample matrices keep the CSV's row/column layout, and
        every row is standardized with the mean and standard deviation fitted
        on the training file only.
    """
    frames = [pd.read_csv(path, header=None) for path in (train_file, val_file)]
    (train_x, train_y), (val_x, val_y) = [
        (frame.iloc[1:-1, 1:].values.astype(float), frame.iloc[-1, 1:].values.astype(float))
        for frame in frames
    ]

    # StandardScaler standardizes columns; transposing in and out makes it
    # standardize each row across the sample columns instead.
    scaler = StandardScaler()
    scaled_train = scaler.fit_transform(train_x.T).T
    scaled_val = scaler.transform(val_x.T).T
    return scaled_train, train_y, scaled_val, val_y

In [None]:
# MLP model
class MLP(nn.Module):
    """
    Three-layer fully connected network that outputs one logit per sample.

    Both hidden layers use a sigmoid activation, and dropout (p=0.3) follows
    the first hidden layer only. The output layer has no activation: the raw
    logit is returned, which is what ``nn.BCEWithLogitsLoss`` (used by the
    training and evaluation helpers in this notebook) expects.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2):
        super().__init__()
        # The layers are created (and therefore randomly initialised and
        # registered) in this exact order; the state_dict ordering relied on by
        # the federated parameter helpers follows it.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, 1)
        # Parameter-free modules shared by the hidden layers.
        self.sigmoid = nn.Sigmoid()  # hidden-layer activation; NOT applied to the output
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Map inputs of shape (..., input_size) to logits of shape (..., 1)."""
        first_hidden = self.dropout(self.sigmoid(self.fc1(x)))
        second_hidden = self.sigmoid(self.fc2(first_hidden))
        return self.fc3(second_hidden)

In [None]:
# Early stopping criteria
class EarlyStopping:
    """Track a validation loss and flag when it has stopped improving."""

    def __init__(self, patience=5, min_delta=0):
        """
        Set up the early-stopping state.

        Parameters:
            patience (int): Number of consecutive non-improving calls tolerated
                before ``early_stop`` is set.
            min_delta (float): A new loss must be at least this much below the
                best loss seen so far to count as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0         # consecutive calls without improvement
        self.best_loss = None    # best loss so far; None until the first call
        self.early_stop = False  # set once counter reaches patience; never cleared

    def __call__(self, val_loss):
        """
        Record ``val_loss`` and update ``counter`` / ``early_stop``.

        The first call only initialises ``best_loss``.

        Parameters:
            val_loss (float): Current validation loss.
        """
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        threshold = self.best_loss - self.min_delta
        if not val_loss > threshold:
            # Improvement (ties count when min_delta == 0): new best, reset count.
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
# Federated learning utility functions

def get_client_parameters(net) -> List[np.ndarray]:
    """
    Snapshot a model's ``state_dict`` as a list of NumPy arrays.

    The arrays follow the entry order of ``net.state_dict()``, which is the
    order ``set_client_parameters`` zips them back against. Each tensor is
    moved to the CPU before conversion.

    Parameters:
        net (nn.Module): The neural network model.

    Returns:
        List[np.ndarray]: One array per ``state_dict`` entry.
    """
    state = net.state_dict()
    return [state[name].cpu().numpy() for name in state]

def set_client_parameters(net, parameters: List[np.ndarray]):
    """
    Load a list of NumPy arrays into ``net`` (inverse of ``get_client_parameters``).

    Arrays are matched to ``net.state_dict()`` entries by position, so they
    must be in the order ``get_client_parameters`` produced them.

    Parameters:
        net (nn.Module): The neural network model, updated in place.
        parameters (List[np.ndarray]): One array per ``state_dict`` entry.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries (a bare ``zip`` would silently drop extras).
        RuntimeError: Propagated from ``load_state_dict`` on shape mismatch.
    """
    keys = list(net.state_dict().keys())
    parameters = list(parameters)
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    # torch.tensor keeps each array's dtype (e.g. integer buffers), whereas the
    # legacy torch.Tensor constructor always produces the default float dtype.
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)

In [None]:
# Test the client model
def test_client_model(net, testloader):
    """
    Evaluate the model on every batch of ``testloader``.

    Parameters:
        net (nn.Module): Model returning one logit per sample, shape (N, 1).
        testloader (DataLoader): Yields ``(inputs, labels)`` batches; labels
            are float binary targets (assumed 0.0/1.0).

    Returns:
        Tuple (loss, accuracy, precision, recall, f1). ``loss`` is the mean
        per-sample binary cross-entropy over all evaluated samples; the other
        metrics threshold the sigmoid probability at 0.5.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    # Summed (not averaged) per batch so that dividing by the sample count
    # below gives a true per-sample mean, even with a short final batch.
    criterion = nn.BCEWithLogitsLoss(reduction="sum")
    correct, total, loss = 0, 0, 0.0
    net.eval()  # Evaluation mode: disables dropout
    all_labels = []
    all_preds = []

    # Disable gradient calculation for inference
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            # squeeze(-1) drops only the trailing logit dimension, so a batch
            # holding a single sample keeps shape (1,) and still matches labels
            # (a bare squeeze() would produce a 0-d tensor and crash the loss).
            outputs = net(inputs).squeeze(-1)
            loss += criterion(outputs, labels).item()
            preds = (torch.sigmoid(outputs) > 0.5).float()  # probabilities -> {0.0, 1.0}
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    if total == 0:
        raise ValueError("testloader yielded no samples to evaluate")

    # Calculate metrics
    loss /= total
    accuracy = correct / total
    # zero_division=0 yields the same 0.0 sklearn falls back to, without warnings.
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    return loss, accuracy, precision, recall, f1

In [None]:
# Train the client model
def train_client_model(net, trainloader, valloader, epochs: int, patience: int = 5, client_id: int = 0):
    """
    Train the model on the client's local dataset with early stopping.

    After every epoch the model is scored on ``valloader`` via
    ``test_client_model``. Whenever the validation loss improves, the weights
    are written to ``best_model_client_{client_id}.pt``. Training stops after
    ``patience`` consecutive epochs without improvement, or as soon as the
    validation accuracy reaches 1.0.

    NOTE(review): the best checkpoint is only written to disk; ``net`` keeps
    the weights of the last epoch that ran, and those are what callers see.

    Parameters:
        net (nn.Module): The neural network model (trained in place).
        trainloader (DataLoader): DataLoader for the training set.
        valloader (DataLoader): DataLoader for the validation set.
        epochs (int): Maximum number of training epochs.
        patience (int): Epochs without validation-loss improvement tolerated.
        client_id (int): Identifier for the client (used in the checkpoint name).
    """
    criterion = nn.BCEWithLogitsLoss()
    # A fresh optimizer per call, so Adam state does not carry across calls.
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        net.train()  # Training mode: enables dropout
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # squeeze(-1) keeps a single-sample final batch at shape (1,), so
            # it still matches ``labels`` (a bare squeeze() would give 0-d).
            outputs = net(inputs).squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set
        val_loss, val_accuracy, _, _, _ = test_client_model(net, valloader)
        print(f"Epoch {epoch+1}: val loss {val_loss:.4f}, val accuracy {val_accuracy:.4f}")

        # Checkpoint on strict improvement of the validation loss
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(net.state_dict(), f"best_model_client_{client_id}.pt")
        else:
            patience_counter += 1

        # Stop training if 100% accuracy is achieved
        if val_accuracy == 1.0:
            print(f"Early stopping at epoch {epoch+1} due to reaching 100% accuracy")
            break

In [None]:
# Define the FlowerClient class for federated learning
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one client's model and local data loaders."""

    def __init__(self, cid, net, trainloader, valloader):
        """
        Initialize a federated learning client.

        Parameters:
            cid (int): Client ID (used in log messages and checkpoint names).
            net (nn.Module): The neural network model.
            trainloader (DataLoader): DataLoader for the training set.
            valloader (DataLoader): DataLoader for the validation set.
        """
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model parameters as NumPy arrays (``config`` is unused)."""
        print(f"[Client {self.cid}] get_parameters")
        return get_client_parameters(self.net)

    def fit(self, parameters, config):
        """
        Load the global parameters, train on local data, and return the update.

        ``config`` may carry ``"local_epochs"`` and ``"patience"`` (e.g. sent by
        a strategy's ``on_fit_config_fn``); when absent they default to 20 and
        5, the values previously hard-coded here.

        Returns:
            (updated parameters, number of training samples, empty metrics dict)
        """
        print(f"[Client {self.cid}] fit, config: {config}")
        set_client_parameters(self.net, parameters)
        options = config or {}
        epochs = int(options.get("local_epochs", 20))
        patience = int(options.get("patience", 5))
        train_client_model(
            self.net,
            self.trainloader,
            self.valloader,
            epochs=epochs,
            patience=patience,
            client_id=self.cid,
        )
        return get_client_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """
        Load the global parameters and score them on the local validation set.

        Returns:
            (loss, number of validation samples, dict with accuracy, precision,
            recall and f1 converted to float)
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        # Set the model parameters with the latest global parameters
        set_client_parameters(self.net, parameters)
        # Evaluate the model
        loss, accuracy, precision, recall, f1 = test_client_model(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {
            "accuracy": float(accuracy),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1)
        }

In [None]:
# Create a Flower client
def client_fn(cid, seed: Optional[int] = None) -> FlowerClient:
    """
    Create and initialize a FlowerClient instance for a given client ID.

    Parameters:
        cid (int): Client ID, interpolated into the CSV file names.
        seed (Optional[int]): Index of the ``Seed_<seed>`` data folder to
            read. When omitted (Flower's simulation passes only the client
            id), the module-level ``seed_index`` set by the cross-validation
            loop is used, as before.

    Returns:
        FlowerClient: Client holding a fresh 300-392-392-1 MLP plus training
        (shuffled) and validation (unshuffled) DataLoaders with batch size 32.
    """
    if seed is None:
        seed = seed_index  # global loop variable of the cross-validation loop

    # Initialize the neural network model with specified architecture
    net = MLP(input_size=300, hidden_size1=392, hidden_size2=392).to(DEVICE)
    # Define file paths for training and validation data
    train_file = f"../No_MSC/Seed_{seed}/Client_{cid}_Train.csv"
    val_file = f"../No_MSC/Seed_{seed}/Client_{cid}_Validation.csv"

    # Load and preprocess the data
    train_samples, train_labels, val_samples, val_labels = load_and_preprocess(train_file, val_file)

    # Transpose so each row of the tensor is one sample (CSV columns are samples)
    X_train = torch.tensor(train_samples.T, dtype=torch.float32)
    y_train = torch.tensor(train_labels, dtype=torch.float32)
    X_val = torch.tensor(val_samples.T, dtype=torch.float32)
    y_val = torch.tensor(val_labels, dtype=torch.float32)

    # Create TensorDatasets and DataLoaders for training and validation
    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=32, shuffle=False)

    return FlowerClient(cid, net, train_loader, val_loader)

In [None]:
# Load combined validation dataset for evaluation
def load_combined_validation(seed_index):
    """
    Build an unshuffled DataLoader over one seed's combined validation CSV.

    The same file is handed to ``load_and_preprocess`` as both the training
    and the validation file, so the standardization statistics are fitted on
    this validation set itself.

    Parameters:
        seed_index (int): Index of the ``Seed_<seed_index>`` data folder.

    Returns:
        DataLoader: Batches of 32 ``(sample, label)`` float32 pairs, one row
        per CSV column.
    """
    csv_path = f"../No_MSC/Seed_{seed_index}/Combined_Validation.csv"
    *_, samples, labels = load_and_preprocess(csv_path, csv_path)
    features = torch.tensor(samples.T, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.float32)
    return DataLoader(TensorDataset(features, targets), batch_size=32, shuffle=False)

In [None]:
# Evaluate the aggregated model on the combined validation dataset
def evaluate_aggregated_model(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """
    Server-side hook: score the aggregated weights on the combined validation set.

    Reads the module-level ``combined_val_loader`` and updates the module-level
    ``best_aggregated_acc``. Whenever the accuracy strictly exceeds it, the
    model's weights are saved to ``best_aggregated_model.pt``.

    Parameters:
        server_round (int): Current federated round (unused).
        parameters (fl.common.NDArrays): Aggregated model parameters.
        config (Dict[str, fl.common.Scalar]): Configuration dictionary (unused).

    Returns:
        (loss, {"accuracy", "precision", "recall", "f1"}) from ``test_client_model``.
    """
    global best_aggregated_acc
    model = MLP(input_size=300, hidden_size1=392, hidden_size2=392).to(DEVICE)
    set_client_parameters(model, parameters)
    loss, accuracy, precision, recall, f1 = test_client_model(model, combined_val_loader)
    print(f"Server-side evaluation - loss: {loss}, accuracy: {accuracy}")

    metrics = {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
    if not accuracy > best_aggregated_acc:
        return loss, metrics

    # New best accuracy: remember it and checkpoint the weights.
    best_aggregated_acc = accuracy
    torch.save(model.state_dict(), "best_aggregated_model.pt")
    print(f"Best aggregated model saved with accuracy {accuracy}")
    return loss, metrics

In [None]:
# Define metric aggregation function using arithmetic mean
def arithmetic_mean(metrics: List[Tuple[int, Dict[str, fl.common.Scalar]]]) -> Dict[str, fl.common.Scalar]:
    """
    Average the client metrics with equal weight per client.

    The per-client sample counts in ``metrics`` are ignored (unweighted mean).

    Parameters:
        metrics (List[Tuple[int, Dict[str, fl.common.Scalar]]]): Pairs of
            (number of samples, metric dict) reported by the clients; each dict
            must contain "accuracy", "precision", "recall" and "f1".

    Returns:
        Dict[str, fl.common.Scalar]: ``np.mean`` of each of those four metrics.
    """
    client_dicts = [client_metrics for _, client_metrics in metrics]
    return {
        name: np.mean([d[name] for d in client_dicts])
        for name in ("accuracy", "precision", "recall", "f1")
    }

In [None]:
# Training and evaluation loop for federated learning.
# For each data seed: run a FedAvg simulation, keep the aggregated model with
# the best server-side accuracy, then re-evaluate it on that seed's combined
# validation set.
combined_val_loaders = []  # combined-validation loader for each seed

# Store metrics for analysis: (loss, accuracy, precision, recall, f1) per seed
all_metrics = []

NUM_SEEDS = 30
aggregated_model_path = "best_aggregated_model.pt"  # written by evaluate_aggregated_model

# Train and evaluate the federated model across all seeds
for seed_index in range(NUM_SEEDS):
    print(f"Seed {seed_index+1}/{NUM_SEEDS}")

    # Load the combined validation dataset for the current seed
    combined_val_loader = load_combined_validation(seed_index)
    combined_val_loaders.append(combined_val_loader)

    # Remove a checkpoint left by a previous seed (or a previous run) so the
    # evaluation below can never pick up another seed's model.
    try:
        os.remove(aggregated_model_path)
    except FileNotFoundError:
        pass
    # -inf (rather than 0.0) guarantees the first server-side evaluation is
    # checkpointed, even if every round scores 0% accuracy.
    best_aggregated_acc = float("-inf")

    # Define the federated learning strategy
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # Use all clients for training in each round
        fraction_evaluate=1.0,  # Use all clients for evaluation in each round
        min_fit_clients=5,  # Minimum number of clients required for training
        min_evaluate_clients=5,  # Minimum number of clients required for evaluation
        min_available_clients=NUM_CLIENTS,  # Minimum number of total clients
        initial_parameters=fl.common.ndarrays_to_parameters(get_client_parameters(MLP(input_size=300, hidden_size1=392, hidden_size2=392))),
        evaluate_fn=evaluate_aggregated_model,  # Evaluation function for server-side evaluation
        evaluate_metrics_aggregation_fn=arithmetic_mean,  # Metric aggregation function
    )

    client_resources = None
    if DEVICE.type == "cuda":
        client_resources = {"num_gpus": 1}

    # Start the federated learning simulation
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=50),  # Number of training rounds
        strategy=strategy,
        client_resources=client_resources,
    )

    # Load and evaluate the best aggregated model for the current seed
    aggregated_model = MLP(input_size=300, hidden_size1=392, hidden_size2=392).to(DEVICE)
    if not os.path.exists(aggregated_model_path):
        print(f"No aggregated model was saved for seed {seed_index+1}; skipping its evaluation")
        continue
    aggregated_model.load_state_dict(torch.load(aggregated_model_path, map_location=DEVICE))
    loss, accuracy, precision, recall, f1 = test_client_model(aggregated_model, combined_val_loader)
    print(f"Aggregated model evaluation - loss: {loss}, accuracy: {accuracy}, precision: {precision}, recall: {recall}, f1: {f1}")
    all_metrics.append((loss, accuracy, precision, recall, f1))

In [None]:
# Collect the per-seed metrics into a DataFrame for plotting
metrics_df = pd.DataFrame(all_metrics, columns=['val_loss', 'val_accuracy', 'val_precision', 'val_recall', 'val_f1'])

# One box plot per metric (the loss is not plotted) on a 2x3 grid; the red dot
# marks the mean across seeds.
plt.figure(figsize=(12, 8))
panels = [
    ('val_accuracy', 'Validation Accuracy', 'Accuracy'),
    ('val_precision', 'Validation Precision', 'Precision'),
    ('val_recall', 'Validation Recall', 'Recall'),
    ('val_f1', 'Validation F1 Score', 'F1 Score'),
]
for panel_position, (column, panel_title, axis_label) in enumerate(panels, start=1):
    plt.subplot(2, 3, panel_position)
    sns.boxplot(y=metrics_df[column])
    plt.title(panel_title)
    plt.xlabel(axis_label)
    mean_val = metrics_df[column].mean()
    plt.scatter(0, mean_val, color='red', s=100, zorder=10)

# Adjust layout to prevent overlap
plt.tight_layout()
plt.show()

In [None]:
# Save the metrics for later comparison
metrics_df.to_csv("fedMetrics/MLP_metrics.csv", index=False)

print("Finished Cross-Validation")