## Parameter Setup

In [1]:
import sys
import os
from datetime import datetime

# Experiment configuration
model = 'mnist'  # mnist or mnist_cnn
local_update_number = 5  # number of local epochs (E); selects the E<n> results subfolder

# Switch to the src directory; the relative ../results/... paths below are resolved from there
os.chdir('../src')

# FedProx mu values to compare
fedprox_mu_values = [0.5, 1.5, 2.5]  # Specify mu values for FedProx

# Build the list of method identifiers to evaluate from the requested FedProx mu values
def create_methods_to_evaluate(fedprox_mu_values, fedprox_exclusive=False):
    """Return the method identifiers to evaluate, in evaluation order.

    Args:
        fedprox_mu_values: Iterable of FedProx mu values; each becomes an entry
            formatted as ``"FedProx_mu{mu}"``.
        fedprox_exclusive: When True, return only the FedProx entries. Otherwise the
            list starts with the ``"FedAvg"`` and ``"LocalTrain"`` baselines.

    Returns:
        list[str]: A newly built list of method names.
    """
    baselines = [] if fedprox_exclusive else ["FedAvg", "LocalTrain"]
    return baselines + [f"FedProx_mu{mu}" for mu in fedprox_mu_values]

# FedProx-only list (used by the per-round plots) and the full list including the FedAvg/LocalTrain baselines
methods_fedprox = create_methods_to_evaluate(fedprox_mu_values, fedprox_exclusive=True)
methods_to_evaluate = create_methods_to_evaluate(fedprox_mu_values)

# Folder holding the result logs for this model and local-epoch setting; figures are written beneath it
result_directory = f'../results/{model}/E{local_update_number}'

## Collect Results

In [None]:
import os
import re
import pandas as pd

# Define the metrics we're interested in
# Summary metrics collected from each log file. This list fixes the row order of the
# summary table, and its entries must cover every key returned by
# extract_results_from_file (results[...][metric] is indexed with those keys).
metric_names = [
    "Global Train Loss",
    "Global Test Loss",
    "Global Test Accuracy",
    "Global Evaluation Loss",
    "Global Evaluation Accuracy",
    "Average Evaluation Loss",
    "Average Evaluation Accuracy",
    "Average Naive Baseline Loss",
    "Average Naive Baseline Accuracy",
    "Average Evaluation Improvement Over Baseline",
    "Average Training Loss",
    "Average Training Accuracy",
    "Average Test Loss",
    "Average Test Accuracy"
]


# Create a dictionary to store the results, with the structure:
# {client_classes: {metric: {method: value}}}
# Values are the raw strings captured from the logs; they are converted to float only when plotted.
results = {}

# Parse the final summary metrics out of a single result log
def extract_results_from_file(file_path):
    """
    Scan a result log and capture the value text of each known summary metric.

    Every line is tested against every pattern. When a metric appears on several lines,
    the last occurrence wins.

    Args:
        file_path (str): Path to the log file.

    Returns:
        dict: Maps metric name to the captured value, kept as a string. Metrics that
        never appear are absent.
    """
    # Log-line regex per metric; the single capture group is the value text
    summary_patterns = {
        "Global Train Loss": r"Global Train Loss across all clients: (.+)",
        "Global Test Loss": r"Global Test Loss: (.+)",
        "Global Test Accuracy": r"Global Test Accuracy: (.+)",
        "Global Evaluation Loss": r"Global Evaluation Loss: (.+)",
        "Global Evaluation Accuracy": r"Global Evaluation Accuracy: (.+)",
        "Average Evaluation Loss": r"Average evaluation loss across clients: (.+)",
        "Average Evaluation Accuracy": r"Average evaluation accuracy across clients: (.+)",
        "Average Naive Baseline Loss": r"Average naive baseline loss across clients: (.+)",
        "Average Naive Baseline Accuracy": r"Average naive baseline accuracy across clients: (.+)",
        "Average Evaluation Improvement Over Baseline": r"Average improvement over naive baseline: (.+)",
        "Average Training Loss": r"Average training loss across clients: (.+)",
        "Average Training Accuracy": r"Average training accuracy across clients: (.+)",
        "Average Test Loss": r"Average test loss across clients: (.+)",
        "Average Test Accuracy": r"Average test accuracy across clients: (.+)"
    }

    found = {}
    with open(file_path, 'r') as handle:
        for text_line in handle:
            for metric_key, regex in summary_patterns.items():
                hit = re.search(regex, text_line)
                if hit is not None:
                    found[metric_key] = hit.group(1)
    return found

# Parse every matching log in a directory into the module-level `results` dictionary
def process_results_in_directory(directory, methods_to_evaluate):
    """
    Fill the global ``results`` dict from the log files in *directory*.

    File names are expected to look like ``output_log_<Method>[_mu<mu>]_<N>cls.txt``.
    Files whose method (including the optional mu suffix) is not in
    *methods_to_evaluate* are skipped.

    Args:
        directory (str): Folder containing the log files.
        methods_to_evaluate (list): Method names to keep (e.g. "FedAvg", "FedProx_mu0.5").

    Returns:
        None. Mutates the module-level ``results`` as
        ``results[client_classes][metric][method] = value``.
    """
    for entry in os.listdir(directory):
        if not entry.endswith(".txt"):
            continue
        parsed = re.search(r"output_log_(.+?)(_mu[\d.]+)?_(\d+)cls.txt", entry)
        if not parsed:
            continue

        base_method, mu_suffix, classes_text = parsed.groups()
        # Re-attach the optional "_mu<value>" part so names match e.g. "FedProx_mu0.5"
        method = base_method + (mu_suffix or "")
        num_classes = int(classes_text)

        if method not in methods_to_evaluate:
            continue

        extracted = extract_results_from_file(os.path.join(directory, entry))

        # First file seen for this client-class count: start an empty table for every metric
        per_metric = results.setdefault(num_classes, {name: {} for name in metric_names})
        for metric, value in extracted.items():
            per_metric[metric][method] = value

# Process all the result files for the specified methods
process_results_in_directory(result_directory, methods_to_evaluate)

# Build one row per (client_classes, metric), with one column per method (None when missing).
# Client classes are visited in ascending order so the table does not depend on the order
# in which os.listdir() happened to return the log files.
data = []
for client_classes in sorted(results):
    for metric, methods in results[client_classes].items():
        data.append([client_classes, metric] + [methods.get(method) for method in methods_to_evaluate])

# Create the DataFrame with columns for client classes, metric, and each method
columns = ["client_classes", "metric"] + methods_to_evaluate
df = pd.DataFrame(data, columns=columns)

# Display the table
print(df)

# Save the summary next to the logs it summarizes, so the path follows `model` and
# summaries for different local-epoch settings (E) do not overwrite each other.
df.to_csv(os.path.join(result_directory, 'summary_results.csv'), index=False)

## Visualization

### Statistical Error

In [3]:
import matplotlib.pyplot as plt
import re
import os

def plot_metric_vs_client_classes(results, metric_name, methods_to_evaluate, output_dir, selected_client_classes=None,
                                  x_label="Statistical Heterogeneity (# Classes)", y_label=None, flip_x_axis=False):
    """
    Plot one summary metric against the number of classes per client, one line per method.

    Args:
        results (dict): Nested mapping ``{client_classes: {metric: {method: value}}}``. Values may be
            numbers or numeric strings; they are converted with ``float``.
        metric_name (str): The metric to plot (e.g., "Global Train Loss"); also used in the output file name.
        methods_to_evaluate (list): Method names to draw, in legend order.
        output_dir (str): Directory where the plot is saved (must already exist).
        selected_client_classes (list, optional): Client-class counts to put on the x-axis.
            If None or empty, all keys of ``results`` are used.
        x_label (str): Label for the x-axis (default "Statistical Heterogeneity (# Classes)").
        y_label (str, optional): Label for the y-axis; defaults to ``metric_name``.
        flip_x_axis (bool): If True, invert the x-axis so larger numbers appear on the left.

    Side effects:
        Saves ``<output_dir>/<metric_name>_vs_Client_Classes.png`` and closes the figure.
    """
    # Figure size, marker size, font sizes and transparency shared by all curves
    fig_size = (4, 4)
    marker_size = 4
    font_size = 10
    legend_font_size = 8
    default_alpha = 1.0
    grid_alpha = 0.3

    # Get colors from the 'tab10' color palette
    tab10_colors = plt.get_cmap('tab10').colors

    # Fixed styles for the baseline methods
    method_styles = {
        'FedAvg': {'color': tab10_colors[0], 'linestyle': '--', 'marker': 'o', 'label': 'FedAvg', 'alpha': 0.9},  # Blue
        'LocalTrain': {'color': tab10_colors[7], 'linestyle': '--', 'marker': 's', 'label': 'LocalTrain', 'alpha': default_alpha}  # Grey
    }

    # Extract FedProx methods and their mu values from names like "FedProx_mu0.5"
    fedprox_methods = [method for method in methods_to_evaluate if 'FedProx' in method]
    mu_values = []
    for method in fedprox_methods:
        match = re.search(r'_mu([0-9.]+)', method)
        if match:
            mu_values.append((method, float(match.group(1))))

    # Sort FedProx methods by their lambda (mu) values so the smallest gets the first style
    mu_values.sort(key=lambda x: x[1])
    mu_labels = ['small lambda', 'medium lambda', 'large lambda']
    fedprox_colors = [tab10_colors[2], tab10_colors[1], tab10_colors[3]]  # Green, Orange, Red
    fedprox_markers = ['D', '*', 'x']  # Diamond, Star, Multiply

    # zip() stops after three entries: further FedProx variants get the fallback style below
    for (method, _), label, color, marker in zip(mu_values, mu_labels, fedprox_colors, fedprox_markers):
        alpha = 0.9 if label == 'large lambda' else default_alpha
        method_styles[method] = {
            'color': color, 'linestyle': '-', 'marker': marker, 'label': f'FedProx ({label})', 'alpha': alpha
        }

    # If selected_client_classes is not provided, use all available client classes
    client_classes_list = sorted(selected_client_classes) if selected_client_classes else sorted(results.keys())

    # Collect the y-values per method; missing entries become NaN so the line shows a gap
    metric_data = {method: [] for method in methods_to_evaluate}
    for client_classes in client_classes_list:
        values_by_method = results.get(client_classes, {}).get(metric_name, {})
        for method in methods_to_evaluate:
            metric_value = values_by_method.get(method)
            metric_data[method].append(float(metric_value) if metric_value is not None else float('nan'))

    # Plot the specified metric for different methods
    plt.figure(figsize=fig_size)
    for method in methods_to_evaluate:
        # Methods without a predefined style would otherwise raise KeyError; fall back to
        # matplotlib's colour cycle (no 'color' key) and label the line with the raw method name.
        style = {'linestyle': '-', 'marker': 'o', 'label': method, 'alpha': default_alpha}
        style.update(method_styles.get(method, {}))
        plt.plot(client_classes_list, metric_data[method], markersize=marker_size, **style)

    plt.xlabel(x_label, fontsize=font_size)
    plt.ylabel(metric_name if y_label is None else y_label, fontsize=font_size)

    # Remove top and right borders
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    # Invert the x-axis if flip_x_axis is True
    if flip_x_axis:
        plt.gca().invert_xaxis()

    # Position the legend below the plot
    plt.legend(fontsize=legend_font_size, loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)

    # Add a grid with transparency
    plt.grid(True, alpha=grid_alpha)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{metric_name}_vs_Client_Classes.png'), bbox_inches='tight')
    plt.close()

In [4]:
plotting_dir = os.path.join(result_directory, "figures")
os.makedirs(plotting_dir, exist_ok=True)

# Client-class counts shown on the x-axis of every heterogeneity plot
heterogeneity_classes = [2, 4, 6, 8, 10]

# (metric key in `results`, y-axis label override or None to use the metric name)
heterogeneity_metrics = [
    ("Global Evaluation Loss", None),
    ("Average Evaluation Loss", None),
    ("Global Train Loss", None),
    ("Average Training Loss", None),
    ("Global Evaluation Accuracy", "Global Test Accuracy"),
    ("Average Evaluation Accuracy", None),
    ("Average Evaluation Improvement Over Baseline", None),
]

# One figure per metric, x-axis flipped so more classes (less heterogeneity) is on the left
for heterogeneity_metric, heterogeneity_ylabel in heterogeneity_metrics:
    plot_metric_vs_client_classes(
        results=results,
        metric_name=heterogeneity_metric,
        methods_to_evaluate=methods_to_evaluate,
        selected_client_classes=heterogeneity_classes,
        output_dir=plotting_dir,
        y_label=heterogeneity_ylabel,
        flip_x_axis=True,
    )

### Total Error

In [5]:
import os
import re
import matplotlib.pyplot as plt

# Read one per-round metric series (e.g. "Global Train losses") from a log file
def extract_metric(file_path, metric_name):
    """
    Extract the values of a comma-separated metric series from a result log.

    The first line containing ``"<metric_name>:"`` is used. Everything after that label
    is split on commas and converted to floats. The following are ignored:
    - text before the label on the same line (e.g. a log prefix);
    - optional surrounding square brackets;
    - empty fields such as a trailing comma.

    Args:
        file_path (str): The path to the file containing results.
        metric_name (str): The label of the metric, without the trailing colon.

    Returns:
        list[float]: The values for the metric, or an empty list if no line carries the label.

    Raises:
        ValueError: If a non-empty field after the label is not a number.
    """
    label = f"{metric_name}:"
    with open(file_path, 'r') as f:
        for line in f:
            if label not in line:
                continue
            # Use only the text after the label, so anything before it cannot break float()
            _, _, tail = line.partition(label)
            fields = tail.strip().strip('[]').split(',')
            return [float(field) for field in fields if field.strip()]
    return []

# Collect one per-round metric for every requested method at a fixed number of client classes
def process_metrics_in_directory(directory, methods_to_evaluate, client_classes, metric_name):
    """
    Read a per-round metric series from each matching log file in *directory*.

    File names are expected to look like ``output_log_<method>_<N>cls.txt``, where
    ``<method>`` may itself contain a mu suffix (e.g. ``FedProx_mu0.5``).

    Args:
        directory (str): Folder containing the log files.
        methods_to_evaluate (list): Method names to keep.
        client_classes (int): Only files with this ``<N>`` are read.
        metric_name (str): Label of the metric line, passed to ``extract_metric``.

    Returns:
        dict: ``{method: [values...]}`` for each method with a non-empty series. Keys follow
        the order of *methods_to_evaluate*, so plot legends and colours are reproducible.
    """
    # Anchored, with the literal dot escaped, so only exact "..._<N>cls.txt" names match
    filename_pattern = re.compile(r"output_log_(.+?)_(\d+)cls\.txt$")

    found = {}
    for file_name in os.listdir(directory):
        match = filename_pattern.search(file_name)
        if match is None:
            continue
        method = match.group(1)
        file_client_classes = int(match.group(2))

        # Only the requested client-class count and methods
        if file_client_classes != client_classes or method not in methods_to_evaluate:
            continue

        metric_values = extract_metric(os.path.join(directory, file_name), metric_name)
        if metric_values:
            found[method] = metric_values

    # Re-key in the caller's method order instead of the arbitrary os.listdir() order
    return {method: found[method] for method in methods_to_evaluate if method in found}

# Plot a per-round metric (one line per method) for a fixed number of client classes
def plot_metric(metric_dict, metric_name, client_classes, output_dir, ylabel=None):
    """
    Plot the specified metric for each method as a function of communication rounds.

    Args:
        metric_dict (dict): Dictionary with methods as keys and per-round metric values as lists.
        metric_name (str): The name of the metric; used as the y-axis label (unless ``ylabel``
            is given) and in the output file name.
        client_classes (int): The number of client classes for the experiment (used in the file name).
        output_dir (str): Directory to save the plot (must already exist).
        ylabel (str, optional): Overrides the y-axis label.

    Side effects:
        Saves ``<output_dir>/<metric_name>_vs_Comm_Rounds_<client_classes>cls.png`` and closes the figure.
    """
    # Set the figure size and marker size
    fig_size = (4, 4)
    marker_size = 4
    font_size = 10
    legend_font_size = 8
    grid_alpha = 0.3  # Transparency for the grid

    # Get colors from the 'tab10' color palette
    tab10_colors = plt.get_cmap('tab10').colors

    # Fixed styles for the baseline methods
    method_styles = {
        'FedAvg': {'color': tab10_colors[0], 'linestyle': '--', 'label': 'FedAvg'},  # Blue
        'LocalTrain': {'color': tab10_colors[7], 'linestyle': '--', 'label': 'LocalTrain'},  # Grey
    }

    # Extract FedProx methods and their lambda (mu) values from names like "FedProx_mu0.5"
    fedprox_methods = [method for method in metric_dict.keys() if 'FedProx' in method]
    mu_values = []
    for method in fedprox_methods:
        match = re.search(r'_mu([0-9.]+)', method)
        if match:
            mu_values.append((method, float(match.group(1))))

    # Sort FedProx methods by their lambda values so the smallest gets the first style
    mu_values.sort(key=lambda x: x[1])
    mu_labels = ['small lambda', 'medium lambda', 'large lambda']
    fedprox_colors = [tab10_colors[2], tab10_colors[1], tab10_colors[3]]  # Green, Orange, Red

    # Assign styles to the sorted FedProx methods (zip() stops after three entries)
    for (method, _), label, color in zip(mu_values, mu_labels, fedprox_colors):
        method_styles[method] = {
            'color': color, 'linestyle': '-', 'label': f'FedProx ({label})'
        }

    # Plot the metric as a function of communication rounds (rounds numbered from 1)
    plt.figure(figsize=fig_size)
    for method, values in metric_dict.items():
        # Unstyled methods would otherwise raise KeyError; use the default colour cycle
        # and the raw method name as the legend label.
        style = {'linestyle': '-', 'label': method}
        style.update(method_styles.get(method, {}))
        plt.plot(range(1, len(values) + 1), values, markersize=marker_size, **style)

    plt.xlabel('# Communication Rounds', fontsize=font_size)
    plt.ylabel(metric_name if ylabel is None else ylabel, fontsize=font_size)

    # Remove top and right borders
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    # Position the legend below the plot
    plt.legend(fontsize=legend_font_size, loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)

    # Add a grid with transparency
    plt.grid(True, alpha=grid_alpha)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{metric_name}_vs_Comm_Rounds_{client_classes}cls.png'), bbox_inches='tight')
    plt.close()

In [6]:
# Define Total Error plot methods (FedProx variants only)
fedprox_mu_values = [0.5, 1.5, 2.5]
# Rebuild the FedProx list from the mu values above so editing them takes effect here.
# Do not overwrite `methods_to_evaluate`, which earlier cells use for the full method set.
methods_fedprox = create_methods_to_evaluate(fedprox_mu_values, fedprox_exclusive=True)

# Specify the number of client classes for which you want to plot the results
client_classes = 2

# Define where to save the plots (created if this cell runs before the heterogeneity plots)
plotting_dir = os.path.join(result_directory, "figures")
os.makedirs(plotting_dir, exist_ok=True)

# (metric label as written in the log file, y-axis label for the figure)
round_metrics = [
    ("Global Train losses", "Global Training Loss"),
    ("Average Client Training Losses", "Local Training Loss"),
    ("Global Evaluation losses", "Global Test Loss"),
    ("Average Client Evaluation losses", "Local Test Loss"),
]

# For each metric: collect the per-round series for every FedProx variant, then plot them
for round_metric, round_ylabel in round_metrics:
    metric_dict = process_metrics_in_directory(result_directory, methods_fedprox, client_classes, round_metric)
    plot_metric(metric_dict, round_metric, client_classes, plotting_dir, ylabel=round_ylabel)

### Local Epochs

In [7]:
import os
import re
import matplotlib.pyplot as plt

# Read one per-round metric series (e.g. "Global Train losses") from a log file
def extract_metric(file_path, metric_name):
    """
    Extract the values of a comma-separated metric series from a result log.

    The first line containing ``"<metric_name>:"`` is used. Everything after that label
    is split on commas and converted to floats. The following are ignored:
    - text before the label on the same line (e.g. a log prefix);
    - optional surrounding square brackets;
    - empty fields such as a trailing comma.

    Args:
        file_path (str): The path to the file containing results.
        metric_name (str): The label of the metric, without the trailing colon.

    Returns:
        list[float]: The values for the metric, or an empty list if no line carries the label.

    Raises:
        ValueError: If a non-empty field after the label is not a number.
    """
    label = f"{metric_name}:"
    with open(file_path, 'r') as f:
        for line in f:
            if label not in line:
                continue
            # Use only the text after the label, so anything before it cannot break float()
            _, _, tail = line.partition(label)
            fields = tail.strip().strip('[]').split(',')
            return [float(field) for field in fields if field.strip()]
    return []

# Collect one method's per-round metric across several local-epoch (E) settings
def process_metrics_for_method(directory, method_to_evaluate, local_update_numbers, metric_name, target_client_classes):
    """
    Gather a per-round metric for a single method from ``<directory>/E<n>/`` subfolders.

    Args:
        directory (str): Parent folder containing one ``E<n>`` subfolder per local-epoch setting.
        method_to_evaluate (str): Exact method name from the file name (e.g. "FedProx_mu0.5").
        local_update_numbers (list): The E values to look for; missing subfolders are skipped.
        metric_name (str): Label of the metric line, passed to ``extract_metric``.
        target_client_classes (int): Only files named ``..._<target_client_classes>cls.txt`` are read.

    Returns:
        dict: ``{E: [values...]}`` for each E whose matching file has a non-empty series,
        in the order of *local_update_numbers*.
    """
    collected = {}
    for epochs in local_update_numbers:
        epoch_dir = os.path.join(directory, f"E{epochs}")
        if not os.path.exists(epoch_dir):
            continue
        for entry in os.listdir(epoch_dir):
            if not entry.endswith(".txt"):
                continue
            parsed = re.search(r"output_log_(.+?)_(\d+)cls.txt", entry)
            if not parsed:
                continue
            same_method = parsed.group(1) == method_to_evaluate
            same_classes = int(parsed.group(2)) == target_client_classes
            if not (same_method and same_classes):
                continue
            values = extract_metric(os.path.join(epoch_dir, entry), metric_name)
            if values:
                collected[epochs] = values
    return collected

# Plot one method's per-round metric, one line per local-epoch setting (E)
def plot_metric_over_local_updates(metric_dict, metric_name, method, output_dir, ylabel=None):
    """
    Plot the specified metric for the selected method over different local update numbers (E values).

    Args:
        metric_dict (dict): Dictionary with local update numbers as keys and per-round metric values as lists.
        metric_name (str): The name of the metric; used as the y-axis label (unless ``ylabel``
            is given) and in the output file name.
        method (str): The method being evaluated (used in the output file name).
        output_dir (str): Directory to save the plot (must already exist).
        ylabel (str, optional): Overrides the y-axis label.

    Side effects:
        Saves ``<output_dir>/<metric_name>_vs_Comm_Rounds_<method>.png`` and closes the figure.
    """
    # Set the figure size and marker size
    fig_size = (4, 4)
    marker_size = 4
    font_size = 10
    legend_font_size = 8
    grid_alpha = 0.3  # Transparency for the grid

    # One line per E value; colours come from matplotlib's default cycle, rounds numbered from 1
    plt.figure(figsize=fig_size)
    for local_update_number, values in metric_dict.items():
        plt.plot(range(1, len(values) + 1), values, label=f"# Local Epochs = {local_update_number}", markersize=marker_size)

    plt.xlabel('# Communication Rounds', fontsize=font_size)
    plt.ylabel(metric_name if ylabel is None else ylabel, fontsize=font_size)

    # Remove top and right borders
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    # Position the legend below the plot
    plt.legend(fontsize=legend_font_size, loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)

    # Add a grid with transparency
    plt.grid(True, alpha=grid_alpha)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{metric_name}_vs_Comm_Rounds_{method}.png'), bbox_inches='tight')
    plt.close()


In [8]:
# Example usage: compare local-epoch settings (E) for one method.
# A separate name is used so the E-specific `result_directory` from the setup cell is not
# overwritten (earlier cells read their logs from it).
model_result_directory = f'../results/{model}'
plotting_dir = os.path.join(model_result_directory, "figures")
os.makedirs(plotting_dir, exist_ok=True)

# Specify the method and the range of E values (local updates) to evaluate
method_to_evaluate = "FedProx_mu0.5"
local_update_numbers = [5, 10, 15, 20]  # Example set of E values
metric_name = "Global Train losses"
target_client_classes = 5

# Process the result files in <model_result_directory>/E<n>/ for the specified method and E values
metric_dict = process_metrics_for_method(model_result_directory, method_to_evaluate, local_update_numbers, metric_name, target_client_classes=target_client_classes)

# Plot the specified metric over communication rounds for different local update numbers
plot_metric_over_local_updates(metric_dict, metric_name, method_to_evaluate, plotting_dir, ylabel="Global Training Loss")

In [9]:
# Example usage: compare local-epoch settings (E) for one method.
# A separate name is used so the E-specific `result_directory` from the setup cell is not
# overwritten (earlier cells read their logs from it).
model_result_directory = f'../results/{model}'
plotting_dir = os.path.join(model_result_directory, "figures")
os.makedirs(plotting_dir, exist_ok=True)

# Specify the method and the range of E values (local updates) to evaluate
method_to_evaluate = "FedProx_mu2.5"
local_update_numbers = [5, 10, 15, 20]  # Example set of E values
metric_name = "Average Client Training Losses"
target_client_classes = 5

# Process the result files in <model_result_directory>/E<n>/ for the specified method and E values
metric_dict = process_metrics_for_method(model_result_directory, method_to_evaluate, local_update_numbers, metric_name, target_client_classes=target_client_classes)

# Plot the specified metric over communication rounds for different local update numbers
plot_metric_over_local_updates(metric_dict, metric_name, method_to_evaluate, plotting_dir, ylabel="Local Training Loss")