In [None]:
from src import datahandler
import src.models as models
from src.models.fcn import FCN_Predictor
import src.utils as utils
import src.server as server
import src.client as client
import flwr as fl
import torch

# --- Run configuration --------------------------------------------------------
# Number of simulated clients and number of federated rounds; both are passed
# to fl.simulation.start_simulation further down in this cell.
NUM_CLIENTS = 10
FL_ROUNDS = 10

# Device string handed to client.client_factory(device=...). Use CUDA when the
# driver process can see a GPU and fall back to the CPU otherwise, so that
# "Restart & Run All" does not request a device that does not exist.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Keyword arguments forwarded verbatim to datahandler.get_datasets(**DATASET_CONFIG).
# 'observation_days' and 'future_days' are also multiplied by the dataset's
# `freq_in_day` to size the model's input and output layers.
DATASET_CONFIG = {
    'dataset': 'london_smartmeter',
    'train_stride': 1,
    'validation_stride': 24,
    'observation_days': 1,
    'future_days': 1,
    'normalize': 'minmax',
}

# Model settings. The same dict is used to build the initial model in this
# cell and is passed to client.client_factory as `m_config`.
MODEL_CONFIG = {
    # Predictor class (not an instance); it is called below to build the model.
    '_model': FCN_Predictor,
    # Forwarded to the predictor constructor as `hidden_size`.
    'hidden_size': 64,
    # NOTE(review): not read anywhere in this notebook; presumably consumed by
    # the client code - confirm its meaning there.
    '_attack_step_multiplier': 1,
}

# Load the train/validation/test collections for column 0 and keep only the
# first element of each; `trainset` is used below to size the model.
trainsets, valsets, testsets = datahandler.get_datasets(**DATASET_CONFIG, columns=0)
trainset = trainsets[0]
valset = valsets[0]
testset = testsets[0]

# Per-client resource request, handed to fl.simulation.start_simulation.
client_resources = {
    "num_cpus": 1,
    "num_gpus": 0.1,
}

# Build the initial model: input/output widths are the dataset's samples-per-day
# (`freq_in_day`) times the configured number of observed / predicted days.
model = MODEL_CONFIG['_model'](
    features=[0],
    hidden_size=MODEL_CONFIG['hidden_size'],
    input_size=trainset.freq_in_day * DATASET_CONFIG['observation_days'],
    output_size=trainset.freq_in_day * DATASET_CONFIG['future_days'],
)

# Extract the model's parameters; they seed the server strategy below.
model_parameters = utils.get_model_parameters(model)

# Build the server-side strategy, seeded with the initial model's parameters
# converted to Flower's `Parameters` container.
initial_parameters = fl.common.ndarrays_to_parameters(model_parameters)
strategy = server.CustomStrategy(
    on_fit_config_fn=server.fit_config,
    evaluate_metrics_aggregation_fn=server.evaluate_metrics_aggregation,
    initial_parameters=initial_parameters,
)

# Factory result that the simulation uses to create each simulated client.
client_fn = client.client_factory(
    d_config=DATASET_CONFIG,
    m_config=MODEL_CONFIG,
    device=DEVICE,
    num_clients=NUM_CLIENTS,
)

# Run the federated simulation for FL_ROUNDS rounds across NUM_CLIENTS clients
# and keep the returned history object for later inspection.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=FL_ROUNDS),
    strategy=strategy,
    client_resources=client_resources,
    # Resource totals for the simulation backend (hard-coded for the author's machine).
    ray_init_args={"num_cpus": 64, "num_gpus": 1},
)

In [None]:
# Diagnostic only: list sockets whose netstat line contains ":8080", including
# the owning process ID (`-o`). `findstr` is a Windows command, so this cell
# does not work on Linux/macOS.
# NOTE(review): 8080 is presumably the port of the standalone server started
# from src/server.py - confirm.
!netstat -ano | findstr :8080

In [None]:
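# NOTE: everything in this cell is disabled (commented out). It launches
# `src/server.py` and two `src/client.py` processes via `subprocess`, streams
# their output, and terminates all of them if any one exits.
# If re-enabled, be aware that it rebinds the names `server` and `client`,
# which shadows the `src.server` / `src.client` modules imported in the first cell.
# import subprocess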
# import time
# import threading

# # Function to continuously read from a process' output
# def read_output(process, label):
#     for line in iter(process.stdout.readline, ''):
#         print(f"{label}: {line}", end='')

# # Function to monitor the processes
# def monitor_processes(server, clients):
#     while True:
#         # Check if server has exited
#         if server.poll() is not None:
#             print("Server has exited unexpectedly. Terminating all processes.")
#             terminate_all_processes(server, clients)
#             break

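#         # Check if any client has exited.
#         # NOTE(review): the `break` below only leaves this `for` loop, not the
#         # outer `while`; the loop ends on a later pass once `server.poll()`
#         # reports the terminated server, which also prints the
#         # "Server has exited unexpectedly" message.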
#         for client in clients:
#             if client.poll() is not None:
#                 print(f"Client {clients.index(client)} has exited unexpectedly. Terminating all processes.")
#                 terminate_all_processes(server, clients)
#                 break

#         # Sleep for a short time to avoid constant polling
#         time.sleep(0.5)

# # Function to terminate all processes
# def terminate_all_processes(server, clients):
#     server.terminate()
#     for client in clients:
#         client.terminate()

# print("Starting server...")
# # Start the server
# server = subprocess.Popen(
#     ["python", "src/server.py"], 
#     stdout=subprocess.PIPE, 
#     stderr=subprocess.STDOUT,
#     text=True,
#     bufsize=1
# )
# threading.Thread(target=read_output, args=(server, "Server"), daemon=True).start()

# # Wait for the server to initialize
# time.sleep(3)
# print("Server started.")

# # Start clients
# clients = []
# for i in range(2):
#     print(f"Starting client {i}...")
#     client = subprocess.Popen(
#         ["python", "src/client.py", f"--partition={i}", "--use_cuda=True"], 
#         stdout=subprocess.PIPE, 
#         stderr=subprocess.STDOUT,
#         text=True,
#         bufsize=1
#     )
#     threading.Thread(target=read_output, args=(client, f"Client {i}"), daemon=True).start()
#     clients.append(client)
#     print(f"Started client {i}")

# # Start monitoring thread
# monitor_thread = threading.Thread(target=monitor_processes, args=(server, clients))
# monitor_thread.start()

# # Use try-except to handle interrupts
# try:
#     # Wait for the monitor thread to finish
#     monitor_thread.join()
# except KeyboardInterrupt:
#     print("Terminating all processes due to KeyboardInterrupt...")
#     terminate_all_processes(server, clients)
#     print("Terminated all processes.")