In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import os
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

##> import libraries
import sys
from pathlib import Path
import random
import time
from itertools import product
from typing import OrderedDict


# Resolve the project root (the parent of the notebook's working directory)
# and make it importable so the `src.*` imports further down succeed.
root_dir = Path.cwd().resolve().parent
if not root_dir.exists():
    raise FileNotFoundError('Root directory not found')
# Re-running this cell must not keep appending the same entry to sys.path.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

#> import flower
import flwr as fl
from flwr.common import Context
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner, NaturalIdPartitioner
from torch.utils.data import DataLoader
from datasets import Dataset


#> import custom libraries
from src.load import load_df_to_dataset
from src.EAE import EvidentialTransformerDenoiseAutoEncoder, evidential_regression
from src.client import EAEClient, evaluate_saved_model
from src.datasets import TrajectoryDataset, clean_outliers_by_quantile
from src.plot import plot_tsne_with_uncertainty, plot_uncertainty

#> torch libraries
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import numpy as np
import statsmodels.api as sm

#> Plot
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots  # https://github.com/garrettj403/SciencePlots?tab=readme-ov-file
plt.style.use(['science', 'grid', 'notebook', 'ieee'])  # , 'ieee'


# %matplotlib inline
# %matplotlib widget


In [None]:
# Locate the dataset assets directory and the saved-model directory
# Dataset assets are looked up relative to the fourth ancestor of the project root.
assets_dir = root_dir.parents[3] / 'aistraj' / 'bin' / 'tvt_assets'
assets_dir = assets_dir.resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# Trained checkpoints are stored inside the repository, under <root>/models.
saved_model_dir = root_dir / 'models'
saved_model_dir = saved_model_dir.resolve()
# The label used to read "Assets Directory" (copy-paste slip); report the right one.
print(f"Model Directory: {saved_model_dir}")
if not saved_model_dir.exists():
    raise FileNotFoundError('Model directory not found')

In [None]:
# Environment setup: multiprocessing start method, data/code paths, Ray and FL settings
# Make sure the "spawn" start method is active. If switching fails, the
# problem is reported as a warning instead of aborting the notebook.
if multiprocessing.get_start_method(allow_none=True) == "spawn":
    pass  # already configured, e.g. by the imports cell
else:
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError as e:
        print(f"Warning: {e}")

# Dataset catalog, given as an absolute path.
# NOTE(review): this overrides the `assets_dir` derived from `root_dir` in the
# previous cell — confirm which of the two definitions is the intended one.
assets_dir = Path("/data1/aistraj/bin/tvt_assets")
assets_dir = assets_dir.resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# The 'src' directory contains only code; it is the directory referenced by
# the Ray runtime environment configured below.
code_dir = (root_dir / 'src').resolve()
print(f"Code Directory: {code_dir}")
if not code_dir.exists():
    raise FileNotFoundError('Code directory not found')

# Names/patterns inside the code directory that are listed as Ray excludes.
excludes = ["data", "*.pyc", "__pycache__"]

# Arguments forwarded to Ray's initialisation by the Flower simulation.
# Options left disabled on purpose: "working_dir" (runtime_env),
# "num_cpus" and "local_mode" (top level).
ray_init_args = dict(
    runtime_env=dict(
        py_modules=[str(code_dir)],
        excludes=[str(code_dir / pattern) for pattern in excludes],
    ),
    include_dashboard=False,
)

# Number of simulated federated clients.
num_clients = 4

# Training settings handed to every EAEClient through its `config` argument.
config = {
    "lambda_reg": 0.5,
    "num_epochs": 1,
    "offset": 2.5,
}

# Build the checkpoint path from the already validated `saved_model_dir`
# instead of a user-specific absolute path, so the notebook also works from
# another checkout location. Kept as `str`, as before.
global_model_save_path = str(
    saved_model_dir / 'feae_model_global_lambda05_random_960_20e.pth'
)

In [4]:
def aggregate_metrics(metrics):
    """Average the evaluation metrics reported by multiple clients.

    Args:
        metrics: Sequence of per-client entries. A valid entry is a 2-tuple
            whose second element is a dict mapping metric name to a numeric
            value; entries of any other shape are reported and skipped.

    Returns:
        dict: Unweighted mean of each metric, keyed by metric name. A metric
        that only some clients report is averaged over the clients that
        reported it (previously this raised ``KeyError`` or dropped the
        metric). Empty when there is nothing valid to aggregate.
    """
    if not metrics:
        return {}

    # Print evaluation metrics for each client for debugging purposes
    for idx, m in enumerate(metrics):
        print(f"Client {idx}: {m}")

    # Keep only the dictionary part of well-formed (x, dict) entries
    metrics_dicts = []
    for m in metrics:
        if isinstance(m, tuple) and len(m) == 2 and isinstance(m[1], dict):
            metrics_dicts.append(m[1])
        else:
            print(f"Unexpected metrics format: {m}")

    if not metrics_dicts:
        print("No valid metrics to aggregate.")
        return {}

    # Group values by metric name over the union of keys, so a client that
    # omits (or adds) a metric no longer breaks the aggregation.
    values_by_key = {}
    for client_metrics in metrics_dicts:
        for key, value in client_metrics.items():
            values_by_key.setdefault(key, []).append(value)

    aggregated = {
        key: sum(values) / len(values) for key, values in values_by_key.items()
    }
    print("All metrics received from clients:", metrics)
    return aggregated


In [5]:
def load_datasets_eval(assets_dir, seq_len=960, batch_size=32):
    """Create the evaluation DataLoader for the extended validation split.

    Args:
        assets_dir: Path-like root of the dataset assets; the parquet file is
            read from ``<assets_dir>/extended``.
        seq_len: Sequence length handed to ``TrajectoryDataset``; also used as
            its ``filter_less_seq_len`` value.
        batch_size: Batch size of the returned loader.

    Returns:
        DataLoader: Non-shuffled loader over the validation trajectories.
    """
    parquet_path = assets_dir / 'extended' / 'cleaned_extended_validate_df.parquet'

    # Load the validation frame and clean outliers in the selected columns
    validation_frame = clean_outliers_by_quantile(
        load_df_to_dataset(parquet_path).data,
        ['speed_c', 'lon', 'lat'],
        remove_na=False,
    )

    dataset_options = dict(
        seq_len=seq_len,
        mode='ae',
        # Features discarded before building the sequences
        drop_features_list=['epoch', 'datetime', 'obj_id', 'traj_id', 'stopped', 'curv', 'abs_ccs'],
        scaler_method='QuantileTransformer',
        filter_less_seq_len=seq_len,
    )
    trajectory_dataset = TrajectoryDataset(validation_frame, **dataset_options)

    return DataLoader(
        trajectory_dataset,
        batch_size=batch_size,
        num_workers=2,
        shuffle=False,
        pin_memory=False,
    )

In [6]:
class CustomFederatedDataset(FederatedDataset):
    """FederatedDataset variant backed by a locally loaded dataset.

    The preparation step is replaced by a no-op, so the caller has to assign
    the ``_dataset`` attribute itself (as done in
    ``load_datasets_fl_with_partitioner``).
    """
    def _prepare_dataset(self) -> None:
        """Override the original method and use the local dataset directly to avoid loading from the Hugging Face Hub."""
        # Only marks the dataset as prepared; nothing is loaded here.
        # NOTE(review): relies on the private `_dataset_prepared` flag of
        # flwr_datasets — confirm against the installed version.
        self._dataset_prepared = True

def _build_partitioner(partitioner_type, num_clients, alpha, partition_column, partitioner_kwargs):
    """Return a new, not-yet-assigned partitioner of the requested type."""
    if partitioner_type == "iid":
        return IidPartitioner(num_partitions=num_clients, **partitioner_kwargs)

    if partitioner_type == "naturalidpartitioner":
        if partition_column is None:
            raise ValueError("partition_column must be specified when using NaturalIdPartitioner.")
        # NOTE(review): NaturalIdPartitioner takes no `num_partitions` here;
        # the caller still iterates over `range(num_clients)` — confirm both
        # agree for the chosen column.
        return NaturalIdPartitioner(partition_by=partition_column, **partitioner_kwargs)

    if partitioner_type == "dirichlet":
        if partition_column is None:
            raise ValueError("partition_column must be specified when using DirichletPartitioner.")
        return DirichletPartitioner(
            num_partitions=num_clients,
            alpha=alpha,
            partition_by=partition_column,
            **partitioner_kwargs,
        )

    raise ValueError(f"Unknown partitioner type: {partitioner_type}")


def load_datasets_fl_with_partitioner(
    assets_dir,
    num_clients,
    seq_len=960,
    batch_size=32,
    partitioner_type="iid",
    alpha=0.5, # Smoothing parameters for the Dirichlet distribution,
    partition_column=None,
    **partitioner_kwargs,
):
    """Partition the local train/validation data and build per-client loaders.

    Args:
        assets_dir: Path-like root of the dataset assets; parquet files are
            read from ``<assets_dir>/extended``.
        num_clients: Number of clients (partitions) to create loaders for.
        seq_len: Sequence length handed to ``TrajectoryDataset``; also used as
            its ``filter_less_seq_len`` value.
        batch_size: Batch size of every DataLoader.
        partitioner_type: ``"iid"``, ``"dirichlet"`` or ``"naturalidpartitioner"``.
        alpha: Concentration parameter for the Dirichlet partitioner.
        partition_column: Column to partition by; required for the
            ``"dirichlet"`` and ``"naturalidpartitioner"`` types.
        **partitioner_kwargs: Extra keyword arguments for the partitioner.

    Returns:
        tuple: ``(train_dataloaders, val_dataloaders, n_features)`` where the
        two lists hold one DataLoader per client and ``n_features`` is taken
        from the first client's training dataset.

    Raises:
        ValueError: Unknown ``partitioner_type`` or missing ``partition_column``.
    """
    # One partitioner instance PER SPLIT. A partitioner keeps the first
    # dataset assigned to it, so sharing a single instance between the train
    # and validation FederatedDatasets made every "validation" partition a
    # slice of the training data. Building them first also rejects a bad
    # configuration before any data is loaded.
    train_partitioner = _build_partitioner(
        partitioner_type, num_clients, alpha, partition_column, partitioner_kwargs
    )
    val_partitioner = _build_partitioner(
        partitioner_type, num_clients, alpha, partition_column, partitioner_kwargs
    )

    # Load local training and validation datasets
    train_parquet_path = assets_dir / "extended" / "cleaned_extended_train_df.parquet"
    train_df_extend = load_df_to_dataset(train_parquet_path).data

    validate_parquet_path = assets_dir / "extended" / "cleaned_extended_validate_df.parquet"
    validate_df_extend = load_df_to_dataset(validate_parquet_path).data

    columns_to_clean = ['speed_c', 'lon', 'lat']  # Specify columns to clean
    train_df_extend = clean_outliers_by_quantile(train_df_extend, columns_to_clean, remove_na=False)
    validate_df_extend = clean_outliers_by_quantile(validate_df_extend, columns_to_clean, remove_na=False)

    print("After correction:")
    print("Unique 'season' values in training dataset:", train_df_extend['season'].unique())
    print("Unique 'season' values in validation dataset:", validate_df_extend['season'].unique())

    # Convert pandas.DataFrame to datasets.Dataset
    train_dataset = Dataset.from_pandas(train_df_extend, preserve_index=False)
    val_dataset = Dataset.from_pandas(validate_df_extend, preserve_index=False)

    # Initialize the custom FederatedDatasets, each with its own partitioner
    fds_train = CustomFederatedDataset(
        dataset="train",
        partitioners={"train": train_partitioner},
    )
    fds_val = CustomFederatedDataset(
        dataset="val",
        partitioners={"val": val_partitioner},
    )

    # Manually assign the locally loaded datasets.Dataset objects
    fds_train._dataset = {"train": train_dataset}
    fds_val._dataset = {"val": val_dataset}

    # Define features to be discarded
    drop_features_list = ["epoch", "datetime", "obj_id", "traj_id", "stopped", "curv", "abs_ccs"]

    # Create per-client data loaders
    train_dataloaders, val_dataloaders = [], []
    n_features = None

    for client_id in range(num_clients):
        train_partition = fds_train.load_partition(client_id, split="train").to_pandas()
        val_partition = fds_val.load_partition(client_id, split="val").to_pandas()
        train_seasons = set(train_partition['season'].unique())
        print(f"Client {client_id + 1} season values: {train_seasons}")

        train_dataset_traj = TrajectoryDataset(
            train_partition,
            seq_len=seq_len,
            mode="ae",
            drop_features_list=drop_features_list,
            scaler_method='QuantileTransformer',
            filter_less_seq_len=seq_len,
        )
        val_dataset_traj = TrajectoryDataset(
            val_partition,
            seq_len=seq_len,
            mode="ae",
            drop_features_list=drop_features_list,
            scaler_method='QuantileTransformer',
            filter_less_seq_len=seq_len,
        )

        train_dataloader = DataLoader(
            train_dataset_traj,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=False,
        )
        val_dataloader = DataLoader(
            val_dataset_traj,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=False,
        )

        train_dataloaders.append(train_dataloader)
        val_dataloaders.append(val_dataloader)

        # The feature count is taken once, from the first client's training set
        if n_features is None:
            n_features = train_dataset_traj.n_features

    return train_dataloaders, val_dataloaders, n_features

In [None]:
# Build one train and one validation DataLoader per client. `input_dim`
# receives the third return value (the feature count) and is reused below as
# the model's input size.
train_dataloaders, val_dataloaders, input_dim = load_datasets_fl_with_partitioner(
    assets_dir=assets_dir,
    num_clients=num_clients,
    seq_len=960,
    batch_size=32,
    partitioner_type="iid", # one of: iid, dirichlet, naturalidpartitioner
    # alpha=0.5,                  # only used by the dirichlet partitioner
    # partition_column='season'   # required for dirichlet / naturalidpartitioner
)

In [None]:
# Sanity check: list the distinct 'season' values seen by each client.
# NOTE(review): relies on TrajectoryDataset exposing a `dataframe` attribute
# that still contains the 'season' column — confirm in src.datasets.
for client_id, dataloader in enumerate(train_dataloaders):
    df = dataloader.dataset.dataframe
    unique_seasons = df['season'].unique()
    print(f"Client {client_id + 1} season values: {unique_seasons}")


In [None]:
# Count, per client, how many valid (unmasked) time steps fall into each season.
# Labels applied to the four positions 0..3 of the reindexed counts.
all_seasons = ['Spring', 'Summer', 'Autumn', 'Winter']

# One pandas Series of season counts per client, in client order
client_season_distributions = []

for client_id, dataloader in enumerate(train_dataloaders):
    season_counts = []
    # Iterates the dataset sample by sample (not in batches)
    for sample in dataloader.dataset:
        inputs = sample['inputs']  # [seq_len, n_features]
        input_mask = sample['input_masks']  # [seq_len]

        # Keep only time steps where the mask is set, and read the last feature.
        # NOTE(review): assumes the last feature column is the season code and
        # that it is still an integer 0..3 after scaling — TODO confirm.
        valid_season_values = inputs[input_mask.bool(), -1].numpy()  
        season_counts.extend(valid_season_values.tolist())  
    
    season_distribution = pd.Series(season_counts).value_counts().sort_index()
    # Force exactly the four index values 0..3; missing ones count as 0
    season_distribution = season_distribution.reindex(range(4), fill_value=0)  

    season_distribution.index = all_seasons
    client_season_distributions.append(season_distribution)

# Rows = seasons, columns = clients
season_distribution_df = pd.DataFrame(client_season_distributions).T
season_distribution_df.columns = [f'Client {i+1}' for i in range(len(client_season_distributions))]
print(season_distribution_df)



In [None]:
# Stacked bar chart of the season counts: one bar per client (the frame is
# transposed so clients are on the x-axis), one colour per season.
season_distribution_df.T.plot(kind='bar', stacked=True, figsize=(12, 8), width=0.7, cmap='tab20')
plt.title("Season Distribution Across Clients", fontsize=25)
# Axis labels are currently disabled:
#plt.xlabel("Clients", fontsize=14)
#plt.ylabel("Count", fontsize=14)
plt.xticks(fontsize=25, rotation=0)  # Make the x-tick labels (Client1, Client2, etc.) larger
plt.yticks(fontsize=25)
plt.legend(title="Season", fontsize=18)
plt.grid(axis='y', linestyle='--', alpha=0.7)  # horizontal grid lines only
plt.tight_layout()
plt.show()

In [15]:
def client_fn(client_id: int):
    """Create the Flower client for one simulated participant.

    Args:
        client_id: Index into the notebook-level ``train_dataloaders`` and
            ``val_dataloaders`` lists.

    Returns:
        The result of ``EAEClient(...).to_client()`` for this participant.
    """
    gpu_available = torch.cuda.is_available()
    device = torch.device("cuda" if gpu_available else "cpu")

    # Per-client data, taken from the lists built earlier in the notebook
    client_train_loader = train_dataloaders[client_id]
    client_val_loader = val_dataloaders[client_id]

    # Fresh model instance for this client
    architecture = dict(
        input_dim=input_dim,
        d_model=8,
        nhead=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=32,
        max_seq_length=960,
        dropout_rate=0.1,
    )
    network = EvidentialTransformerDenoiseAutoEncoder(**architecture).to(device)

    # Loss function and optimizer
    adam = optim.Adam(network.parameters(), lr=1e-3)

    eae_client = EAEClient(
        cid=client_id,
        model=network,
        train_dataloader=client_train_loader,
        val_dataloader=client_val_loader,
        criterion=evidential_regression,
        optimizer=adam,
        device=device,
        config=config,
        save_model_path=None,
    )
    return eae_client.to_client()

In [16]:
class SaveModelFedAvg(fl.server.strategy.FedAvg):
    """FedAvg strategy that saves the aggregated global model after the last round.

    Args:
        save_path: Destination passed to ``torch.save``.
        model_architecture: Zero-argument callable returning a new model whose
            ``state_dict`` layout matches the aggregated parameters.
        device: Device the model is moved to before loading the parameters.
        num_rounds: Total number of rounds; the model is saved when the
            current round equals this value.
        **kwargs: Forwarded unchanged to ``FedAvg``.
    """

    def __init__(self, save_path, model_architecture, device, num_rounds, **kwargs):
        super().__init__(**kwargs)
        self.save_path = save_path
        self.model_architecture = model_architecture
        self.device = device
        self.num_rounds = num_rounds

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate fit results and save the model in the final round.

        Returns the parent's result unchanged.
        """
        # Calling the aggregation methods of the parent class
        aggregated_result = super().aggregate_fit(rnd, results, failures)
        if aggregated_result is not None:
            parameters_aggregated, _ = aggregated_result
            # Saving the global model in the last round
            if rnd == self.num_rounds:
                # FedAvg returns (None, {}) when a round yields no usable
                # results; there is nothing to save then, and converting
                # None to ndarrays would crash.
                if parameters_aggregated is None:
                    print(f"No aggregated parameters at round {rnd}; global model not saved.")
                else:
                    print(f"Saving global model at round {rnd}")
                    self.save_model(parameters_aggregated)
        return aggregated_result

    def save_model(self, parameters):
        """Load ``parameters`` into a new model and write its state dict to ``save_path``."""
        # Use the concrete class; the notebook-level `OrderedDict` comes from
        # `typing`, which is meant for annotations.
        from collections import OrderedDict

        # Convert parameters to NumPy format
        params_ndarrays = fl.common.parameters_to_ndarrays(parameters)

        # Initialization Model
        model = self.model_architecture().to(self.device)

        # Pair every state-dict key with the array at the same position
        params_dict = zip(model.state_dict().keys(), params_ndarrays)
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in params_dict)
        model.load_state_dict(state_dict, strict=True)

        # Saving Models
        torch.save(model.state_dict(), self.save_path)
        print(f"Global model saved to {self.save_path}")


In [None]:
# Number of federated rounds; also tells the strategy which round is the last
# one (the round after which the global model is saved).
num_rounds = 20

# FedAvg-based strategy that writes the global model to `global_model_save_path`.
strategy = SaveModelFedAvg(
    save_path=global_model_save_path,
    # Factory for an empty model; the hyper-parameters repeat those in `client_fn`.
    model_architecture=lambda: EvidentialTransformerDenoiseAutoEncoder(
        input_dim=input_dim,
        d_model=8,
        nhead=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=32,
        max_seq_length=960,
        dropout_rate=0.1,
    ),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    num_rounds=num_rounds,  
    # Every client takes part in every fit and evaluate round.
    fraction_fit=1,
    fraction_evaluate=1,
    min_fit_clients=num_clients,
    min_evaluate_clients=num_clients,
    min_available_clients=num_clients,
    evaluate_metrics_aggregation_fn=aggregate_metrics,
)


# Run the federated simulation. The client id is cast to int before it is
# used to index the per-client dataloader lists.
history = fl.simulation.start_simulation(
    client_fn=lambda cid: client_fn(int(cid)),
    num_clients=num_clients,
    config=fl.server.ServerConfig(num_rounds=num_rounds),
    strategy=strategy,
    client_resources={
        "num_cpus": 1.0,
        # Request a GPU share only when CUDA is available. `client_fn` already
        # falls back to CPU, but an unconditional GPU request cannot be
        # scheduled on a CPU-only host.
        "num_gpus": 0.25 if torch.cuda.is_available() else 0.0,
    },
    ray_init_args=ray_init_args,
)

In [None]:
# Untrained model instance used for evaluation; the hyper-parameters repeat
# those used in `client_fn` and in the strategy's model factory.
# NOTE(review): presumably they must match for the saved checkpoint to load —
# confirm in `evaluate_saved_model`.
global_model = EvidentialTransformerDenoiseAutoEncoder(
    input_dim=input_dim,
    d_model=8,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    max_seq_length=960,
    dropout_rate=0.1
)

In [None]:
# Load the full (unpartitioned) validation set with the function's default
# seq_len (960) and batch_size (32).
val_dataloader_traj = load_datasets_eval(assets_dir)

In [None]:
# Evaluate the saved global model on the validation loader; the call returns
# seven values, unpacked below.
# NOTE(review): `lambda_reg=1` differs from the training `config` value (0.5)
# — confirm that this is intended.
(
    val_loss,
    val_aleatoric_uncertainties,
    val_epistemic_uncertainties,
    avg_aleatoric_uncertainty,
    avg_epistemic_uncertainty,
    latent_representations_eval,
    recon_error,
) = evaluate_saved_model(
    model_class=global_model,
    model_path=global_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_dataloader_traj,
    lambda_reg=1,
    offset=2.5,
    # Same CPU fallback as the rest of the notebook instead of a hard-coded 'cuda'
    device='cuda' if torch.cuda.is_available() else 'cpu',
    return_latent=True
)


In [None]:
# t-SNE plot of the evaluation latent representations with the aleatoric uncertainties
plot_tsne_with_uncertainty(latent_representations_eval, val_aleatoric_uncertainties, uncertainty_type='aleatoric')

In [None]:
# t-SNE plot of the evaluation latent representations with the epistemic uncertainties
plot_tsne_with_uncertainty(latent_representations_eval, val_epistemic_uncertainties, uncertainty_type='epistemic')