# In [None]:
import copy
import os

import torch

from client import ClientTrainer
from cloud_server import CloudServer
from dataset import get_client_dataloaders, load_and_split_dataset
from dp_utils import get_dp_training_config
from edge_server import EdgeServer
from models import load_autoencoder_encoder

def main(
    num_edges=2,
    clients_per_edge=3,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
):
    """Run the hierarchical (client -> edge -> cloud) federated training pipeline.

    Steps:
      1. Load/split the dataset and build one (train, test) loader pair per
         client via ``get_client_dataloaders``.
      2. Load the pre-trained autoencoder encoder onto the selected device.
      3. Train every client with ``ClientTrainer`` (DP settings forwarded),
         saving each client's returned state to ``client_{i}_params.pt``.
      4. Aggregate client states per edge with ``EdgeServer``, aggregate the
         edge models with ``CloudServer``, and save the result to
         ``global_model.pt``.

    Args:
        num_edges: Number of edge servers.
        clients_per_edge: Number of clients grouped under each edge server.
        noise_multiplier: DP noise multiplier forwarded to every ClientTrainer.
        max_grad_norm: DP gradient-norm bound forwarded to every ClientTrainer.

    Raises:
        ValueError: If ``num_edges`` or ``clients_per_edge`` is not positive,
            or if the number of client loaders returned does not equal
            ``num_edges * clients_per_edge``.
    """
    if num_edges < 1 or clients_per_edge < 1:
        raise ValueError(
            f"num_edges and clients_per_edge must be positive, got "
            f"{num_edges} and {clients_per_edge}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Federated setup: clients are assigned to edges in contiguous groups of
    # `clients_per_edge` (EdgeServer receives the flat list plus this size).
    num_clients = num_edges * clients_per_edge

    # Load dataset
    print("Loading and preprocessing CICIoMT2024 dataset...")
    train_dataset, test_dataset = load_and_split_dataset()
    # Materialize so the count can be checked; a mismatch would silently
    # misassign clients to edges during aggregation.
    client_dataloaders = list(
        get_client_dataloaders(train_dataset, test_dataset, num_clients)
    )
    if len(client_dataloaders) != num_clients:
        raise ValueError(
            f"Expected {num_clients} client dataloaders, "
            f"got {len(client_dataloaders)}"
        )

    # Load pre-trained Autoencoder encoder
    print("Loading pre-trained Autoencoder encoder...")
    encoder = load_autoencoder_encoder().to(device)

    # Train local clients
    client_states = []
    for i, (train_loader, test_loader) in enumerate(client_dataloaders):
        print(f"[Client {i+1}] Training with Differential Privacy...")
        # Each client gets an independent copy of the pre-trained encoder.
        # If ClientTrainer updates the encoder in place (not visible here —
        # TODO confirm), a shared instance would carry one client's updates
        # into the next and could alias state returned by earlier clients.
        client_encoder = copy.deepcopy(encoder)
        trainer = ClientTrainer(
            encoder=client_encoder,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
        )
        client_model_state = trainer.train()
        client_states.append(client_model_state)
        torch.save(client_model_state, f"client_{i}_params.pt")

    # Aggregate at Edge Servers
    print("Aggregating models at Edge Servers...")
    edge_server = EdgeServer(client_states, clients_per_edge)
    edge_models = edge_server.aggregate()

    # Aggregate at Cloud Server
    print("Aggregating global model at Cloud Server...")
    cloud_server = CloudServer(edge_models)
    global_model = cloud_server.aggregate()
    torch.save(global_model, "global_model.pt")

    print("Training pipeline completed. Global model saved as global_model.pt")


if __name__ == "__main__":
    # Run the pipeline with the default configuration when executed directly.
    main()
