In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
import numpy as np
from config import CONFIG
from model import DeepClassifier
from utils import load_data, preprocess_data, balance_data, split_data, create_clients
from train import train_local, average_models

def main():
    """Run federated training over simulated clients and evaluate the best global model.

    Steps:
      1. Seed torch/numpy, load the two training CSVs and the test CSV, and
         preprocess them with ``preprocess_data``.
      2. Split a validation set off the training data, then balance only the
         remaining training rows.
      3. Build ``CONFIG['num_clients']`` client datasets from the training rows
         with ``create_clients``.
      4. For ``CONFIG['global_rounds']`` rounds: start every client from a copy
         of the global weights, train it with ``train_local``, and replace the
         global model with ``average_models(local_models)``.
      5. Track validation accuracy per round, keep the best global weights
         (also written to ``./checkpoints/best_global_model.pt``), plot the
         curves, and report the best model once on the untouched test set.

    Raises:
        RuntimeError: if no round produced a comparable validation accuracy
            (e.g. ``CONFIG['global_rounds'] == 0`` or an empty validation
            split), so there is no best model to evaluate.
    """
    sns.set(style="whitegrid")

    # Seed the global torch/numpy RNGs so weight init, dropout and client
    # shuffling are reproducible on Restart & Run All.
    torch.manual_seed(CONFIG['random_state'])
    np.random.seed(CONFIG['random_state'])

    # Load and preprocess
    train_paths = ['./data/training_dataset_1.csv', './data/training_dataset_2.csv']
    test_path = './data/test_dataset.csv'

    train_data, test_data = load_data(train_paths, test_path)
    X_train, y_train = preprocess_data(train_data)
    X_test, y_test = preprocess_data(test_data)

    # Split BEFORE balancing: if balance_data resamples rows, balancing first
    # could place copies of the same row in both train and validation, which
    # inflates the validation accuracy used for best-model selection below.
    X_train, X_val, y_train, y_val = split_data(X_train, y_train, CONFIG['test_size'], CONFIG['random_state'])
    X_train, y_train = balance_data(X_train, y_train)

    clients = create_clients(X_train, y_train, CONFIG['num_clients'])

    # Validation data drives per-round model selection; the real test set is
    # only used once, after training.
    val_loader = _make_loader(X_val, y_val, shuffle=False)

    global_model = _build_model()

    # Metrics storage
    global_accuracies = []
    training_histories = []

    checkpoint_dir = './checkpoints'
    best_model_path = f"{checkpoint_dir}/best_global_model.pt"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Start below any reachable accuracy so the first evaluated round always
    # becomes the best model (with 0.0, a run that never beats 0 would never
    # save, and the final load would crash or pick up a stale file).
    best_val_accuracy = float('-inf')
    best_state = None

    rounds_bar = tqdm(range(CONFIG['global_rounds']), desc="🌍 Global Rounds", colour="green")

    for round_idx in rounds_bar:
        local_models = []
        round_history = []

        clients_bar = tqdm(enumerate(clients), total=len(clients), desc=f"🤝 Clients Round {round_idx+1}", leave=False, colour="blue")

        for i, client_data in clients_bar:
            # Every client starts the round from the current global weights.
            model = _build_model()
            model.load_state_dict(global_model.state_dict())
            train_loader = DataLoader(client_data, batch_size=CONFIG['batch_size'], shuffle=True)

            client_checkpoint_path = f"{checkpoint_dir}/client_{i}_round_{round_idx}.pt"

            history = train_local(model, train_loader, CONFIG, val_loader=val_loader, save_path=client_checkpoint_path)
            round_history.append(history)
            local_models.append(model)

        global_model = average_models(local_models)

        acc = _evaluate_accuracy(global_model, val_loader)
        global_accuracies.append(acc)
        training_histories.append(round_history)

        rounds_bar.set_postfix({"Validation Acc": f"{acc:.4f}"})

        if acc > best_val_accuracy:
            best_val_accuracy = acc
            # Keep an in-memory copy so the final evaluation does not depend on
            # re-reading the checkpoint file from disk.
            best_state = {k: v.detach().clone() for k, v in global_model.state_dict().items()}
            torch.save(best_state, best_model_path)
            # tqdm.write keeps the message from breaking the progress bars.
            tqdm.write(f"✅ New best model saved at round {round_idx+1} with accuracy {acc:.4f}")

    # Visualization
    plot_performance(global_accuracies, training_histories)

    if best_state is None:
        raise RuntimeError(
            "No best global model was selected: CONFIG['global_rounds'] must be >= 1 "
            "and the validation split must yield a finite accuracy."
        )

    # Load best model for final test evaluation
    print("\n🏆 Loading best model for final evaluation on test set...")
    best_model = _build_model()
    best_model.load_state_dict(best_state)
    best_model.eval()

    final_test_loader = _make_loader(X_test, y_test, shuffle=False)

    class_names = [str(i) for i in range(CONFIG['output_size'])]

    final_accuracy = evaluate_full(best_model, final_test_loader, class_names=class_names)
    print(f"\n🎯 Final Test Accuracy on unseen data: {final_accuracy:.4f}")


def _build_model():
    """Instantiate a fresh DeepClassifier from the architecture settings in CONFIG."""
    return DeepClassifier(
        CONFIG['input_size'],
        CONFIG['hidden_sizes'],
        CONFIG['dropout'],
        CONFIG['output_size'],
    )


def _make_loader(features, labels, shuffle=False):
    """Wrap ``features``/``labels`` (objects exposing ``.values``, e.g. pandas) in a DataLoader.

    Features become float32 tensors and labels int64 (``torch.long``) tensors,
    batched with ``CONFIG['batch_size']``.
    """
    dataset = torch.utils.data.TensorDataset(
        torch.tensor(features.values, dtype=torch.float32),
        torch.tensor(labels.values, dtype=torch.long),
    )
    return DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=shuffle)


def _evaluate_accuracy(model, data_loader):
    """Return top-1 accuracy of ``model`` over ``data_loader`` (NaN if it yields no samples).

    The predicted class is the argmax of the model output along dim 1.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            preds = model(X_batch).argmax(dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.numel()
    return correct / total if total else float('nan')

def plot_performance(global_accuracies, training_histories):
    """Plot per-round validation accuracy and per-client training curves.

    Draws a 3-row figure:
      1. validation accuracy of the aggregated global model per round;
      2. each client's ``'val_loss'`` series per local epoch;
      3. each client's ``'lr'`` series per local epoch.

    Args:
        global_accuracies: One validation accuracy per global round
            (index 0 is round 1).
        training_histories: One list per round, each holding the history
            returned by ``train_local`` for every client in that round. A
            history is expected to be a dict whose ``'val_loss'`` / ``'lr'``
            entries are per-epoch sequences. Missing or empty entries are
            skipped.
    """
    rounds = list(range(1, len(global_accuracies) + 1))

    fig, (acc_ax, loss_ax, lr_ax) = plt.subplots(3, 1, figsize=(16, 18))

    sns.lineplot(ax=acc_ax, x=rounds, y=global_accuracies, marker='o', label='Validation Accuracy', linewidth=2, color="darkblue")
    acc_ax.set_title('Validation Accuracy Over Global Rounds', fontsize=18)
    acc_ax.set_xlabel('Global Round', fontsize=14)
    acc_ax.set_ylabel('Accuracy', fontsize=14)
    acc_ax.set_xticks(rounds)
    acc_ax.set_ylim(0, 1)
    acc_ax.legend()
    acc_ax.grid(True)

    _plot_client_curves(loss_ax, training_histories, 'val_loss', 'Val Loss')
    loss_ax.set_title('Validation Loss Over Epochs', fontsize=18)
    loss_ax.set_xlabel('Epoch', fontsize=14)
    loss_ax.set_ylabel('Loss', fontsize=14)
    loss_ax.grid(True)

    _plot_client_curves(lr_ax, training_histories, 'lr', 'Learning Rate')
    lr_ax.set_title('Learning Rate Over Epochs', fontsize=18)
    lr_ax.set_xlabel('Epoch', fontsize=14)
    lr_ax.set_ylabel('Learning Rate', fontsize=14)
    lr_ax.grid(True)

    plt.tight_layout()
    plt.show()


def _plot_client_curves(ax, training_histories, key, label_prefix):
    """Draw one line per (round, client) history whose ``key`` series is non-empty.

    Each series is plotted against an explicit 1-based local-epoch index, so the
    x values never have to be inferred. Each line is labelled with its round and
    client so legend entries are distinguishable.
    """
    max_clients = 1
    plotted = False
    for round_no, client_histories in enumerate(training_histories, start=1):
        max_clients = max(max_clients, len(client_histories))
        for client_no, history in enumerate(client_histories):
            series = history.get(key)
            # Explicit length check: truthiness is ambiguous for arrays/tensors.
            if series is None or len(series) == 0:
                continue
            epochs = range(1, len(series) + 1)
            ax.plot(epochs, list(series), linewidth=1,
                    label=f"{label_prefix} (round {round_no}, client {client_no})")
            plotted = True
    if plotted:
        ax.legend(fontsize='small', ncol=max_clients)

def evaluate_full(model, data_loader, class_names=None):
    """Evaluate ``model`` on ``data_loader``. Plot a confusion matrix, print a
    classification report, and return the overall accuracy.

    Args:
        model: A ``torch.nn.Module`` whose output has per-class scores along
            dim 1. The predicted class is the argmax over that dim.
        data_loader: Iterable of ``(X_batch, y_batch)`` pairs with integer
            class labels.
        class_names: Optional display names, where ``class_names[i]`` labels
            class index ``i``. When given, every class ``0..len(class_names)-1``
            appears in the matrix and report, even if it is absent from this data.

    Returns:
        Fraction of samples predicted correctly, as a float. Returns NaN if the
        loader yields no batches; nothing is plotted in that case.
    """
    model.eval()
    # Run batches on the model's device (CPU if the model has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            outputs = model(X_batch.to(device))
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            label_chunks.append(y_batch.cpu().numpy())

    if not pred_chunks:
        print("⚠️ evaluate_full: data_loader yielded no batches; nothing to report.")
        return float('nan')

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    acc = float(np.mean(all_preds == all_labels))

    # Pin the label set when names are supplied. Otherwise sklearn sizes the
    # matrix/report by the classes actually present, which mismatches
    # class_names (ValueError) whenever a class is missing from this split.
    labels = np.arange(len(class_names)) if class_names is not None else None

    cm = confusion_matrix(all_labels, all_preds, labels=labels)
    print("\n📊 Confusion Matrix:")
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap='Blues', xticks_rotation=45)
    plt.show()

    print("\n📋 Classification Report:")
    # zero_division=0 yields the same scores as the default, minus the warning
    # for classes that are never predicted.
    print(classification_report(all_labels, all_preds, labels=labels,
                                target_names=class_names, zero_division=0))

    return acc

In [None]:
if __name__ == "__main__":
    # Script entry point. A Jupyter kernel also runs cells with
    # __name__ == "__main__", so executing this cell launches the full
    # training + evaluation pipeline in main().
    main()