In [None]:
import argparse
import yaml
import os
import tensorflow as tf
import sys
import numpy as np
import random
import pandas as pd
import timeit # For timing in FlirtServer's final metrics
import shutil # For clearing temp folder if needed in setup

# Make the project's ``src`` package importable for the ``from src...`` imports
# that follow, by putting the parent directory on the module search path.
# '../' is a relative entry, so it is resolved against the current working
# directory (not against this file's location). The guard keeps re-running the
# cell from stacking duplicate entries.
# NOTE(review): this presumes the kernel/script is started one level below the
# directory that contains ``src``; if it is started from the project root the
# extra entry is most likely just harmless. Confirm against your layout.
if '../' not in sys.path:
    sys.path.insert(0, '../') # Adjust if your notebook is deeper or shallower

from src.utils.data_utils import load_and_split_data, split_data_into_batches
from src.fedavg.server import FedAvgServer
# from src.fedprox.server import FedProxServer # Uncomment if you need FedProx
from src.flirt.server import FlirtServer
from src.utils.model_utils import report_metric_summary # For aggregated metrics display


def main(config_path):
    """Run one federated-learning experiment described by a YAML config file.

    The function loads the configuration, seeds the Python/NumPy/TensorFlow
    random generators, loads and splits the two node datasets, cuts the two
    training frames into client batches, builds the server selected by
    ``model_name``, runs federated training and prints the final metrics.

    Args:
        config_path: Path to the YAML configuration file. Keys read directly
            here: ``model_name`` (default ``'Flirt'``), ``dataset_type``
            (default ``'HSP'``), ``random_state`` (default ``38``),
            ``dataset_node1_path``, ``dataset_node2_path``, ``test_size`` and
            ``client_sample_size``. The whole mapping is also forwarded to the
            server constructor.

    Raises:
        ValueError: If the file does not contain a YAML mapping, or if
            ``model_name`` is not one of ``FedAvg``, ``FedProx``, ``Flirt``.
        ImportError: If ``FedProx`` is selected but ``src.fedprox.server``
            cannot be imported.
        KeyError: If a required configuration key is missing.
    """
    print(f"Loading configuration from: {config_path}")
    # Pin the encoding so the result does not depend on the platform default.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # An empty file yields None; fail with a clear message instead of an
    # AttributeError on the first ``config.get``.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}."
        )

    model_name = config.get('model_name', 'Flirt')
    dataset_type = config.get('dataset_type', 'HSP')
    print(f"Running Federated Learning with {model_name} model for {dataset_type} dataset...")

    # Validate the model choice before doing any expensive data loading.
    nn_models = ("FedAvg", "FedProx")
    if model_name not in nn_models and model_name != "Flirt":
        raise ValueError(f"Unknown model name: {model_name}. Please check 'model_name' in config.")

    # Set random seeds for reproducibility. The same (defaulted) seed is also
    # used for the data split so a config without 'random_state' still works.
    seed = config.get('random_state', 38)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # 1. Load and split data. The helper returns five values; the second and
    #    fourth are not used by this script.
    train_1_df, _, train_2_df, _, global_test_data_df = load_and_split_data(
        node1_path=config['dataset_node1_path'],
        node2_path=config['dataset_node2_path'],
        test_size=config['test_size'],
        random_state=seed
    )

    input_dim = None  # Only the NN-based servers receive an input dimension.
    feature_names = None
    target_label_names = None

    if model_name in nn_models:
        # Assuming last column is target, others are features.
        input_dim = global_test_data_df.shape[1] - 1
    else:
        # Flirt: every column except the last is a feature.
        feature_names = list(global_test_data_df.columns[:-1])

        # Coerce numeric-looking labels such as '0.0' or 0.0 to plain ints.
        labels = global_test_data_df.iloc[:, -1].astype(float).astype(int)
        # Assign by column name rather than ``iloc[:, -1] = ...``: positional
        # assignment sets values in place and can keep the old (float/object)
        # dtype, whereas column assignment replaces the column with int data.
        # NOTE(review): only the global test frame is normalized here; the
        # training frames are passed on as loaded - confirm that is intended.
        global_test_data_df[global_test_data_df.columns[-1]] = labels

        # Expose the class labels as sorted strings, e.g. ['0', '1'].
        target_label_names = sorted(labels.unique().astype(str).tolist())

        print(f"Feature Names: {feature_names}")
        print(f"Target Label Names: {target_label_names}")

    # Split training data into client batches.
    train_1_batches = split_data_into_batches(train_1_df, config['client_sample_size'])
    train_2_batches = split_data_into_batches(train_2_df, config['client_sample_size'])
    print(f"Prepared {len(train_1_batches)} batches for Group 1 and {len(train_2_batches)} batches for Group 2.")

    # 2. Initialize and run the server.
    if model_name == "FedAvg":
        server = FedAvgServer(input_dim=input_dim, global_test_data=global_test_data_df, config=config)
    elif model_name == "FedProx":
        # The module-level FedProx import is commented out, so resolve the
        # class here; this keeps the other models usable without it.
        try:
            from src.fedprox.server import FedProxServer
        except ImportError as err:
            raise ImportError(
                "model_name is 'FedProx' but src.fedprox.server.FedProxServer "
                "could not be imported."
            ) from err
        server = FedProxServer(input_dim=input_dim, global_test_data=global_test_data_df, config=config)
    else:
        server = FlirtServer(
            config=config,
            global_test_data=global_test_data_df,
            feature_names=feature_names,
            target_label_names=target_label_names
        )
    server.run_federated_training(train_1_batches, train_2_batches)

    # 3. Report results. Judging by the log, FlirtServer already reports its
    #    metrics during training, so this call repeats them for convenience.
    server.print_final_metrics()
    if model_name in nn_models:
        # NOTE(review): assumes the NN servers expose ``get_model_summary()``
        # returning a mapping with the two keys below - confirm against the
        # server classes.
        model_summary = server.get_model_summary()
        print(f"Global Model Parameters: {model_summary['total_parameters']}")
        print(f"Global Model Memory (bytes): {model_summary['memory_bytes']}")


# This block ensures the main function is called correctly whether run as a script or in Jupyter
if __name__ == "__main__":
    # Command-line interface: one optional flag selecting the experiment config.
    parser = argparse.ArgumentParser(description="Run Federated Learning experiments.")
    parser.add_argument(
        '--config',
        type=str,
        default='configs/flirt_config_vp.yaml',
        help='Path to the configuration YAML file.',
    )

    # Probe for the ``get_ipython`` name: if it resolves we treat the process
    # as an IPython/Jupyter session, otherwise as a plain script run.
    try:
        get_ipython  # type: ignore
    except NameError:
        IN_JUPYTER = False
    else:
        IN_JUPYTER = True

    if not IN_JUPYTER:
        # Plain script: unknown arguments are an error.
        args = parser.parse_args()
    else:
        # Notebook kernels pass extra arguments of their own (e.g. the kernel
        # connection file), so tolerate and report anything unrecognized.
        args, unknown = parser.parse_known_args()
        if unknown:
            print(f"Ignored unknown arguments from Jupyter: {unknown}")

    main(args.config)

Ignored unknown arguments from Jupyter: ['--f=c:\\Users\\ahmad\\AppData\\Roaming\\jupyter\\runtime\\kernel-v3a82a5e4357a6ce6acd7a16165196db7fd80e08db.json']
Loading configuration from: configs/flirt_config_vp.yaml
Running Federated Learning with Flirt model for VP dataset...
Feature Names: ['N', 'F0', 'm', 'd_ms', 'd0', 'v0', 'prob']
Target Label Names: ['0', '1']
Prepared 63 batches for Group 1 and 63 batches for Group 2.
Total global test samples: 1328
Prepared 63 batches for Group1 and 63 for Group2 (total 126)
Round 1/250
  sampled G1:3 G2:3
Round 2/250
  sampled G1:3 G2:3
Round 3/250
  sampled G1:3 G2:3
Round 4/250
  sampled G1:3 G2:3
Round 5/250
  sampled G1:3 G2:3
Round 6/250
  sampled G1:3 G2:3
Round 7/250
  sampled G1:3 G2:3
Round 8/250
  sampled G1:3 G2:3
Round 9/250
  sampled G1:3 G2:3
Round 10/250
  sampled G1:3 G2:3
Round 11/250
  sampled G1:3 G2:3
Round 12/250
  sampled G1:3 G2:3
Round 13/250
  sampled G1:3 G2:3
Round 14/250
  sampled G1:3 G2:3
Round 15/250
  sampled G1:3