## Parameter Setup and Experiments

In [None]:
import subprocess
import sys
import numpy as np
import os; os.chdir('../src') # For server
from datetime import datetime

# Parameters for all methods
methods = ['FedAvg', 'LocalTrain', 'FedProx']
mu_values = [0.02, 0.1, 0.5]
L = 9.13  # Smoothness constant L used in the FedProx learning-rate formulas below

# FedProx learning rates per mu: lr_global = 1/mu + 1/L, lr_local = 1/(L + mu)
mu_global_lr_dict = {mu: (1/mu + 1/L) for mu in mu_values}
mu_local_lr_dict = {mu: (1/(L + mu)) for mu in mu_values}
lr_local_localtrain = 0.001  # chosen so that lr_local * mu < 1
lr_local_fedavg = 0.01

random_seed = 1
# FOR ATR: 10 evenly spaced R values from 0.1 to 3 (endpoints included).
# NOTE: despite its name, R_values_0_to_2 spans [0.1, 3]; the name is kept
# because the visualization cell reads it.
R_values_0_to_2 = np.linspace(0.1, 3, 10).tolist()
# Combined grid of R values swept below
R_combined = R_values_0_to_2

num_clients = 10
num_samples = 200
input_dim = 10
num_classes = 2
local_epochs = 150
data_dir = '../data/fedprox_syndata/test'
output_data_dir = '../results/test'
stopping_threshold = -1
# Rounds for FedAvg and FedProx. LocalTrain uses its own budget below
# (original note: LOCAL TRAIN convergence issue).
num_rounds = 50
local_train_base_rounds = 500  # TEST: round budget for LocalTrain

# Parameters for data generation (aliases of the model/client settings above)
num_devices = num_clients
x_dim = input_dim
b_dim = num_classes


def run_simulation(method, R, lr_local, rounds, lr_global=None, mu=None):
    """Run simulation/simulation_main.py once for `method` on the data for `R`.

    Shared settings (client count, model dims, local epochs, data/output
    directories, stopping threshold) are read from the configuration above.
    `lr_global` and `mu` are passed only when given (used for FedProx).

    Arguments are passed as an argv list (no shell), so paths containing
    spaces or shell metacharacters are handled safely. The child runs under
    the kernel's own interpreter (sys.executable) rather than whatever
    `python` happens to be first on PATH.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    cmd = [sys.executable, 'simulation/simulation_main.py',
           '--method', method, '--num_clients', str(num_clients)]
    if lr_global is not None:
        cmd += ['--lr_global', str(lr_global)]
    cmd += ['--lr_local', str(lr_local)]
    if mu is not None:
        cmd += ['--mu', str(mu)]
    cmd += ['--input_dim', str(input_dim), '--num_classes', str(num_classes),
            '--local_epochs', str(local_epochs),
            '--data_dir', data_dir, '--output_data_dir', output_data_dir,
            '--stopping_threshold', str(stopping_threshold),
            '--num_rounds', str(rounds),
            '--R', str(R), '--record_error', 'stat']
    subprocess.run(cmd, check=True)


start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"Start time: {start_time}")

# Loop through each choice of R
for R in R_combined:
    # Generate data for the current R. The same data_dir is used for every R
    # (presumably overwritten by generate_data.py), so each method below runs
    # on the data generated for this R.
    print(f"Generating data for R={R} with {num_devices} devices...")
    subprocess.run([sys.executable, 'simulation/generate_data.py',
                    '--R', str(R), '--num_devices', str(num_devices),
                    '--n_samples', str(num_samples),
                    '--x_dim', str(x_dim), '--b_dim', str(b_dim),
                    '--seed', str(random_seed),
                    '--output_dir', data_dir], check=True)

    for method in methods:
        if method == 'FedProx':
            # One run per mu, each with its own global/local learning rates
            for mu, global_lr_fedprox in mu_global_lr_dict.items():
                local_lr_fedprox = mu_local_lr_dict[mu]
                print(f"Running {method} with mu={mu} for R={R}...")
                run_simulation(method, R, local_lr_fedprox, num_rounds,
                               lr_global=global_lr_fedprox, mu=mu)
        elif method == 'LocalTrain':
            # Same round budget for every R (no extra R-dependent iterations)
            local_train_num_rounds = local_train_base_rounds
            print(f"Running {method} for R={R}...")
            run_simulation(method, R, lr_local_localtrain, local_train_num_rounds)
        elif method == 'FedAvg':
            print(f"Running {method} for R={R}...")
            run_simulation(method, R, lr_local_fedavg, num_rounds)
        else:
            # Fail loudly rather than silently running an unknown method with FedAvg settings
            raise ValueError(f"Unknown method: {method!r}")

end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"End time: {end_time}")

## Visualize the Results

In [None]:
# FOR ATR
# Run the statistical-error visualization over the same R grid, methods and
# FedProx mu values used in the experiment cell. Each value is passed as its
# own argv entry (no shell), and the script runs under the kernel's
# interpreter (sys.executable) rather than `python` from PATH.
command = [
    sys.executable, 'simulation/visualize_stat_error.py',
    '--output_data_dir', output_data_dir,
    '--R_values', *map(str, R_values_0_to_2),
    '--methods', *methods,
    '--mu_values', *map(str, mu_values),
]

# Raises subprocess.CalledProcessError if the script exits non-zero
subprocess.run(command, check=True)