In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import multiprocessing
# Force the "spawn" start method, overriding any method already set (force=True).
# NOTE(review): presumably needed by the DataLoader workers / CUDA usage further down — confirm.
multiprocessing.set_start_method("spawn", force=True)

##> import libraries
import sys
from pathlib import Path
import random
import time
from itertools import product
from typing import OrderedDict


# Repository root = parent of the notebook's working directory. It is put on
# sys.path so that the project's `src` package can be imported below.
root_dir = Path.cwd().resolve().parent
if not root_dir.exists():
    raise FileNotFoundError('Root directory not found')
# Guarded so that re-running this cell does not append the same entry again.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

#> import flower
import flwr as fl
from flwr.common import Context
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner
from torch.utils.data import DataLoader
from datasets import Dataset


#> import custom libraries
from src.load import load_df_to_dataset
from src.EAE import EvidentialTransformerDenoiseAutoEncoder, evidential_regression
from src.client import train_and_evaluate_local, evaluate_saved_model
from src.datasets import TrajectoryDataset, clean_outliers_by_quantile, generate_ood_data
from src.plot import plot_tsne_with_uncertainty, visualize_mean_features

#> torch libraries
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, roc_curve

#> Plot
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots  # https://github.com/garrettj403/SciencePlots?tab=readme-ov-file
# plt.style.use(['science', 'grid', 'notebook', 'ieee'])  # , 'ieee'


# %matplotlib inline
# %matplotlib widget


In [None]:
 # Define the dataset catalog
# Dataset assets live under aistraj/bin/tvt_assets, relative to the fourth
# ancestor of the repository root.
assets_dir = (root_dir.parents[3] / 'aistraj' / 'bin' / 'tvt_assets').resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# Directory holding the saved model checkpoints evaluated in this notebook.
saved_model_dir = (root_dir / 'models').resolve()
# This line used to be labelled "Assets Directory", which made the output misleading.
print(f"Model Directory: {saved_model_dir}")
if not saved_model_dir.exists():
    raise FileNotFoundError('Model directory not found')

In [None]:
 # setup_environment()
# Switch to the "spawn" start method only when it is not already active;
# a RuntimeError from the switch is reported as a warning instead of aborting.
try:
    if multiprocessing.get_start_method(allow_none=True) != "spawn":
        multiprocessing.set_start_method("spawn", force=True)
except RuntimeError as e:
    print(f"Warning: {e}")

# Dataset assets location. Defaults to the shared /data1 path; set the
# TVT_ASSETS_DIR environment variable to use a copy stored elsewhere.
# NOTE: this overrides the `assets_dir` computed from `root_dir` in the earlier cell.
assets_dir = Path(os.environ.get("TVT_ASSETS_DIR", "/data1/aistraj/bin/tvt_assets")).resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# Locate the 'src' directory (code only); the Ray runtime environment
# defined further down points at it.
code_dir = (root_dir / 'src').resolve()
print(f"Code Directory: {code_dir}")
if not code_dir.exists():
    raise FileNotFoundError('Code directory not found')

# Entries (relative to the code directory) to leave out of the runtime environment.
excludes = ["data", "*.pyc", "__pycache__"]

# Keyword arguments prepared for Ray initialisation.
# NOTE(review): no visible cell consumes `ray_init_args` — confirm it is still needed.
ray_init_args = {
    "runtime_env": {
        #"working_dir": str(code_dir),
        "py_modules": [str(code_dir)],
        # Each exclude entry is joined onto the code directory path.
        "excludes": [str(code_dir / pattern) for pattern in excludes],
    },
    "include_dashboard": False,
    #"num_cpus": 4,
    # "local_mode": True
}

# Presumably the number of federated clients — confirm against the training notebook.
num_clients = 4

# config = {
#     "lambda_reg": 1,     
#     "num_epochs": 1,        
#     "offset": 2.5,       
# }



In [None]:
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, roc_curve

def calculate_ood_metrics(id_scores, ood_scores, threshold_method='percentile', percentile=95, k=1.0, reduce_method='mean'):
    """Measure how well a per-sample score separates ID from OOD samples.

    Higher scores are treated as "more OOD": in-distribution (ID) samples are
    labelled 0, out-of-distribution (OOD) samples are labelled 1, and a sample
    is predicted OOD when its score is strictly greater than the threshold.
    The threshold is derived from the ID scores only.

    Args:
        id_scores: Array-like scores of the ID samples. Inputs with more than
            one dimension are reduced to one value per sample (first axis is
            the sample axis) using ``reduce_method``.
        ood_scores: Array-like scores of the OOD samples; reduced the same
            way, independently of the dimensionality of ``id_scores``.
        threshold_method: ``'percentile'`` (the ``percentile``-th percentile
            of the ID scores) or ``'mean_std'`` (mean + ``k`` * std of the ID
            scores).
        percentile: Percentile used by the ``'percentile'`` method.
        k: Standard-deviation multiplier used by the ``'mean_std'`` method.
        reduce_method: ``'mean'``, ``'max'`` or ``'l2'``; only consulted for
            inputs with more than one dimension.

    Returns:
        Tuple ``(f1, auroc, aupr, detection_error, threshold)`` where
        ``detection_error`` is the minimum over ROC operating points of
        ``0.5 * (FPR + (1 - TPR))``.

    Raises:
        ValueError: If either score set is empty, or if ``reduce_method`` /
            ``threshold_method`` is not one of the supported values.
    """

    def _per_sample(scores):
        """Collapse every non-sample axis so each sample has one scalar score."""
        arr = np.atleast_1d(np.asarray(scores))
        if arr.size == 0:
            raise ValueError("id_scores and ood_scores must not be empty.")
        if arr.ndim == 1:
            return arr
        # Flatten trailing axes so inputs with more than two dimensions also
        # end up with exactly one value per sample.
        flat = arr.reshape(arr.shape[0], -1)
        if reduce_method == 'mean':
            return flat.mean(axis=1)
        if reduce_method == 'max':
            return flat.max(axis=1)
        if reduce_method == 'l2':
            return np.linalg.norm(flat, axis=1)
        raise ValueError("Invalid reduce_method. Available options are 'mean', 'max', 'l2'.")

    # Reduce each score set on its own: previously the OOD scores were only
    # reduced when the ID scores happened to be multi-dimensional.
    id_scores = _per_sample(id_scores)
    ood_scores = _per_sample(ood_scores)

    # Decision threshold, estimated from the ID scores only.
    if threshold_method == 'percentile':
        threshold = np.percentile(id_scores, percentile)
    elif threshold_method == 'mean_std':
        threshold = np.mean(id_scores) + k * np.std(id_scores)
    else:
        raise ValueError("threshold_method must be 'percentile' or 'mean_std'")

    # ID samples first (label 0), OOD samples second (label 1).
    all_scores = np.concatenate([id_scores, ood_scores], axis=0)
    all_labels = np.concatenate([np.zeros(len(id_scores)), np.ones(len(ood_scores))], axis=0)

    # Hard predictions are needed only for F1; the ranking metrics use raw scores.
    predictions = (all_scores > threshold).astype(int)

    f1 = f1_score(all_labels, predictions)
    auroc = roc_auc_score(all_labels, all_scores)
    aupr = average_precision_score(all_labels, all_scores)

    # Detection error: best achievable average of FPR and miss rate along the ROC curve.
    fpr, tpr, _ = roc_curve(all_labels, all_scores)
    detection_error = np.min(0.5 * (fpr + (1 - tpr)))

    return f1, auroc, aupr, detection_error, threshold


In [None]:
# Checkpoint files of the three training regimes compared in this notebook.
# `saved_model_dir` is a pathlib.Path, so the file name must be joined with `/`
# (`Path + str` raises TypeError); the result is converted to a plain string.
sequential_model_save_path = str(saved_model_dir / 'eae_model_sequential_lambda05_960.pth')
global_model_save_path = str(saved_model_dir / 'feae_model_global_lambda05_random_960.pth')
local_model_save_path = str(saved_model_dir / 'eae_model_qt_lambda05_960.pth')

In [None]:
def load_datasets_eval(assets_dir, seq_len=960, batch_size=32, ood_mean=10, ood_std=3, num_workers=2):
    """Build the validation and synthetic-OOD data loaders used for evaluation.

    Only the validation parquet is read. The training frame that used to be
    loaded, cleaned and concatenated here was never used by anything returned
    from this function, so that work has been removed.

    Args:
        assets_dir: Directory (str or Path) containing
            ``extended/cleaned_extended_validate_df.parquet``.
        seq_len: Sequence length passed to both ``TrajectoryDataset`` objects.
        batch_size: Batch size of both data loaders.
        ood_mean: Forwarded to ``generate_ood_data`` as ``ood_mean``.
        ood_std: Forwarded to ``generate_ood_data`` as ``ood_std``.
        num_workers: Number of worker processes for each data loader.

    Returns:
        Tuple ``(val_dataloader, ood_dataloader, n_features, val_dataset)``
        where ``n_features`` is read from the validation dataset.
    """
    assets_dir = Path(assets_dir)

    # Validation frame; the OOD frame is derived from it *before* outlier cleaning.
    validate_path = assets_dir / 'extended' / 'cleaned_extended_validate_df.parquet'
    validate_df_extend = load_df_to_dataset(validate_path).data
    ood_df = generate_ood_data(validate_df_extend, ood_mean=ood_mean, ood_std=ood_std)

    # Features that are discarded before the data is fed to the model.
    drop_features_list = ['epoch', 'datetime', 'obj_id', 'traj_id', 'stopped', 'curv', 'abs_ccs']
    # Columns on which quantile-based outlier cleaning is applied.
    columns_to_clean = ['speed_c', 'lon', 'lat']

    cleaned_val_data = clean_outliers_by_quantile(validate_df_extend, columns_to_clean, remove_na=False)

    val_dataset_traj = TrajectoryDataset(
        cleaned_val_data,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        scaler_method='QuantileTransformer',
        filter_less_seq_len=6
        #categorical_features=['season']
    )

    # NOTE(review): the OOD dataset is built with no scaler and no filtering,
    # unlike the validation dataset above — confirm this asymmetry is intended.
    val_ood_dataset_traj = TrajectoryDataset(
        ood_df,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        filter_percent=None,
        scaler=None,
        filter_less_seq_len=None,
        scaler_method='No_Scaler'
    )

    # Both loaders share the same settings; order is preserved (shuffle=False)
    # so per-sample outputs line up with the dataset rows.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=False,
    )
    val_dataloader_traj = DataLoader(val_dataset_traj, **loader_kwargs)
    val_ood_dataloader_traj = DataLoader(val_ood_dataset_traj, **loader_kwargs)

    return val_dataloader_traj, val_ood_dataloader_traj, val_dataset_traj.n_features, val_dataset_traj

In [None]:
# Build the validation and OOD loaders with the default seq_len / batch_size.
(
    val_dataloader_traj,
    val_ood_dataloader_traj,
    input_dim,
    dataset_traj,
) = load_datasets_eval(assets_dir)

In [None]:
# Display the shape of the `inputs` attribute of the OOD dataset behind the OOD loader.
val_ood_dataloader_traj.dataset.inputs.shape

In [None]:
# Display the shape of the `inputs` attribute of the validation dataset behind the validation loader.
val_dataloader_traj.dataset.inputs.shape

## Federated

In [None]:
# Network instance handed to evaluate_saved_model together with the federated
# (global) checkpoint path.
# NOTE(review): hyperparameters presumably have to match the ones used for training — confirm.
global_model = EvidentialTransformerDenoiseAutoEncoder(
    # data-dependent sizes
    input_dim=input_dim,
    max_seq_length=960,
    # transformer width / attention heads
    d_model=8,
    nhead=4,
    dim_feedforward=32,
    # depth
    num_encoder_layers=2,
    num_decoder_layers=2,
    # regularisation
    dropout_rate=0.1,
)

In [None]:
# Evaluate the federated (global) checkpoint on the validation loader.
(
    val_loss_g,
    val_aleatoric_uncertainties_g,
    val_epistemic_uncertainties_g,
    avg_aleatoric_uncertainty_g,
    avg_epistemic_uncertainty_g,
    latent_representations_eval_g,
    recon_error_g,
) = evaluate_saved_model(
    model_class=global_model,
    model_path=global_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_dataloader_traj,
    lambda_reg=0.5,
    offset=2.5,
    device='cuda',
    return_latent=True,
)


In [None]:
# Mean reconstruction error of the global model on the validation data.
print("Average:", sum(recon_error_g) / len(recon_error_g))

In [None]:
# Evaluate the same federated (global) checkpoint on the OOD loader.
(
    val_ood_loss_g,
    val_ood_aleatoric_uncertainties_g,
    val_ood_epistemic_uncertainties_g,
    avg_ood_aleatoric_uncertainty_g,
    avg_ood_epistemic_uncertainty_g,
    latent_ood_representations_eval_g,
    recon_ood_error_g,
) = evaluate_saved_model(
    model_class=global_model,
    model_path=global_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_ood_dataloader_traj,
    lambda_reg=0.5,
    offset=2.5,
    device='cuda',
    return_latent=True,
)

In [None]:
# Mean reconstruction error of the global model on the OOD data.
print("Average:", sum(recon_ood_error_g) / len(recon_ood_error_g))

In [None]:
# Stack validation results first and OOD results second, so that entry i of
# every combined container refers to the same sample.
combined_latent_representations_g = np.concatenate([latent_representations_eval_g, latent_ood_representations_eval_g], axis=0)
combined_val_epistemic_uncertainties_g = np.concatenate([val_epistemic_uncertainties_g, val_ood_epistemic_uncertainties_g], axis=0)
combined_val_aleatoric_uncertainties_g = np.concatenate([val_aleatoric_uncertainties_g, val_ood_aleatoric_uncertainties_g], axis=0)
# Convert to lists before `+`: on NumPy arrays `+` adds element-wise (or fails
# on a shape mismatch) instead of concatenating.
combined_recon_error_g = list(recon_error_g) + list(recon_ood_error_g)

In [None]:
# 0 for every validation sample, followed by 1 for every OOD sample.
ood_labels_g = [0] * len(latent_representations_eval_g)
ood_labels_g += [1] * len(latent_ood_representations_eval_g)
plot_tsne_with_uncertainty(
    combined_latent_representations_g,
    ood_labels_g,
    uncertainty_type='ood label',
)

In [None]:
# Threshold = 95th percentile of the validation-only reconstruction errors.
# NOTE(review): the variable name says "98" but the percentile used here is 95.
percentile_98_g = np.percentile(recon_error_g, 95)
print(percentile_98_g)
plot_tsne_with_uncertainty(
    latent_representations_eval_g,
    recon_error_g,
    uncertainty_type='recon_error',
    threshold=percentile_98_g,
)


In [None]:
# Threshold = 98th percentile of the reconstruction errors over validation + OOD samples.
percentile_98_g = np.percentile(combined_recon_error_g, 98)
print(percentile_98_g)
plot_tsne_with_uncertainty(
    combined_latent_representations_g,
    combined_recon_error_g,
    uncertainty_type='recon_error',
    threshold=percentile_98_g,
)


In [None]:
# Threshold = 95th percentile of epistemic uncertainty over validation + OOD samples.
# NOTE(review): the variable name says "98" but the percentile used here is 95.
percentile_98_g_uncertainty = np.percentile(combined_val_epistemic_uncertainties_g, 95)
print(percentile_98_g_uncertainty)
plot_tsne_with_uncertainty(
    combined_latent_representations_g,
    combined_val_epistemic_uncertainties_g,
    uncertainty_type='val_epistemic_uncertainties',
    threshold=percentile_98_g_uncertainty,
)

In [None]:
# Validation-only view of epistemic uncertainty; reuses the threshold computed on the combined (validation + OOD) values.
plot_tsne_with_uncertainty(latent_representations_eval_g, val_epistemic_uncertainties_g, uncertainty_type='val_epistemic_uncertainties without ood', threshold = percentile_98_g_uncertainty)

In [None]:
# Threshold = 95th percentile of aleatoric uncertainty over validation + OOD samples.
# NOTE(review): the variable name says "98" but the percentile used here is 95.
percentile_98_g_uncertainty_a = np.percentile(combined_val_aleatoric_uncertainties_g, 95)
print(percentile_98_g_uncertainty_a)
plot_tsne_with_uncertainty(
    combined_latent_representations_g,
    combined_val_aleatoric_uncertainties_g,
    uncertainty_type='val_aleatoric_uncertainties',
    threshold=percentile_98_g_uncertainty_a,
)

In [None]:
# Validation-only view of aleatoric uncertainty. The threshold has to be the
# aleatoric one (`..._uncertainty_a`); the epistemic threshold was passed here
# before, which looks like a copy-paste slip from the epistemic cell.
plot_tsne_with_uncertainty(
    latent_representations_eval_g,
    val_aleatoric_uncertainties_g,
    uncertainty_type='val_aleatoric_uncertainties without ood',
    threshold=percentile_98_g_uncertainty_a,
)

In [None]:
# OOD-detection metrics with reconstruction error as the score (global model).
(f1_g, auroc_g, aupr_g, detection_error_g, threshold_g) = calculate_ood_metrics(
    recon_error_g,
    recon_ood_error_g,
    threshold_method='percentile',
    percentile=95,
)
print(f"Reconstruction Error F1 score (FEAE): {f1_g:.4f}, AUROC: {auroc_g}, AUPR: {aupr_g}, Detection Error: {detection_error_g}, Threshold: {threshold_g:.4f}")

In [None]:
# OOD-detection metrics with epistemic uncertainty as the score (global model).
(f1_g_eu, auroc_eu, aupr_eu, detection_error_eu, threshold_g_eu) = calculate_ood_metrics(
    val_epistemic_uncertainties_g,
    val_ood_epistemic_uncertainties_g,
    threshold_method='percentile',
    percentile=95,
)
print(f"Epistemic Uncertainty F1 score (FEAE): {f1_g_eu:.4f}, AUROC: {auroc_eu}, AUPR: {aupr_eu}, Detection Error: {detection_error_eu}, Threshold: {threshold_g_eu:.4f}")

In [None]:
# OOD-detection metrics with aleatoric uncertainty as the score (global model).
(f1_g_au, auroc_au, aupr_au, detection_error_au, threshold_g_au) = calculate_ood_metrics(
    val_aleatoric_uncertainties_g,
    val_ood_aleatoric_uncertainties_g,
    threshold_method='percentile',
    percentile=95,
)
print(f"Aleatoric Uncertainty F1 score (FEAE): {f1_g_au:.4f}, AUROC: {auroc_au}, AUPR: {aupr_au}, Detection Error: {detection_error_au}, Threshold: {threshold_g_au:.4f}")

### Analysis

In [None]:
#for key in val_dataset_traj.labels.keys():
# for key in ['epoch', 'stopped', 'cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
#        'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
#        'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat', 'obj_id', 'datetime',
#        'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
#        'hour_cos']:
# for key in ['cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
#        'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
#        'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat',
#        'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
#        'hour_cos']:
#     plot_tsne_with_uncertainty(latent_representations_eval_g, dataset_traj.labels[key], uncertainty_type=key)

## Local

In [None]:
# Network instance handed to evaluate_saved_model together with the local
# checkpoint path; same hyperparameters as the other two models in this notebook.
local_model = EvidentialTransformerDenoiseAutoEncoder(
    # data-dependent sizes
    input_dim=input_dim,
    max_seq_length=960,
    # transformer width / attention heads
    d_model=8,
    nhead=4,
    dim_feedforward=32,
    # depth
    num_encoder_layers=2,
    num_decoder_layers=2,
    # regularisation
    dropout_rate=0.1,
)

In [None]:
# Evaluate the local checkpoint on the validation loader.
(
    val_loss_l,
    val_aleatoric_uncertainties_l,
    val_epistemic_uncertainties_l,
    avg_aleatoric_uncertainty_l,
    avg_epistemic_uncertainty_l,
    latent_representations_eval_l,
    recon_error_l,
) = evaluate_saved_model(
    model_class=local_model,
    model_path=local_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_dataloader_traj,
    lambda_reg=0.5,
    offset=2.5,
    device='cuda',
    return_latent=True,
)


In [None]:
# Mean reconstruction error of the local model on the validation data.
print("Average:", sum(recon_error_l) / len(recon_error_l))

In [None]:
# Evaluate the same local checkpoint on the OOD loader.
(
    val_ood_loss_l,
    val_ood_aleatoric_uncertainties_l,
    val_ood_epistemic_uncertainties_l,
    avg_ood_aleatoric_uncertainty_l,
    avg_ood_epistemic_uncertainty_l,
    latent_ood_representations_eval_l,
    recon_ood_error_l,
) = evaluate_saved_model(
    model_class=local_model,
    model_path=local_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_ood_dataloader_traj,
    lambda_reg=0.5,
    offset=2.5,
    device='cuda',
    return_latent=True,
)

In [None]:
# Stack validation results first and OOD results second, so that entry i of
# every combined container refers to the same sample.
combined_latent_representations_l = np.concatenate([latent_representations_eval_l, latent_ood_representations_eval_l], axis=0)
combined_val_epistemic_uncertainties_l = np.concatenate([val_epistemic_uncertainties_l, val_ood_epistemic_uncertainties_l], axis=0)
combined_val_aleatoric_uncertainties_l = np.concatenate([val_aleatoric_uncertainties_l, val_ood_aleatoric_uncertainties_l], axis=0)
# Convert to lists before `+`: on NumPy arrays `+` adds element-wise (or fails
# on a shape mismatch) instead of concatenating.
combined_recon_error_l = list(recon_error_l) + list(recon_ood_error_l)

In [None]:
# 0 for every validation sample, followed by 1 for every OOD sample.
ood_labels_l = [0] * len(latent_representations_eval_l)
ood_labels_l += [1] * len(latent_ood_representations_eval_l)
plot_tsne_with_uncertainty(
    combined_latent_representations_l,
    ood_labels_l,
    uncertainty_type='ood label',
)

In [None]:
# Threshold = 95th percentile of the reconstruction errors over validation + OOD samples.
# NOTE(review): the variable name says "98" but the percentile used here is 95.
percentile_98_l = np.percentile(combined_recon_error_l, 95)
print(percentile_98_l)

plot_tsne_with_uncertainty(
    combined_latent_representations_l,
    combined_recon_error_l,
    uncertainty_type='recon_error',
    threshold=percentile_98_l,
)

In [None]:
# Threshold = 95th percentile of epistemic uncertainty over validation + OOD samples.
# NOTE(review): the variable name says "98" but the percentile used here is 95.
percentile_98_l_uncertainty = np.percentile(combined_val_epistemic_uncertainties_l, 95)
print(percentile_98_l_uncertainty)
plot_tsne_with_uncertainty(
    combined_latent_representations_l,
    combined_val_epistemic_uncertainties_l,
    uncertainty_type='val_epistemic_uncertainties',
    threshold=percentile_98_l_uncertainty,
)

In [None]:
# Validation-only view of epistemic uncertainty; reuses the threshold computed on the combined (validation + OOD) values.
plot_tsne_with_uncertainty(latent_representations_eval_l, val_epistemic_uncertainties_l, uncertainty_type='val_epistemic_uncertainties without ood', threshold = percentile_98_l_uncertainty)

In [None]:
# Threshold = 95th percentile of aleatoric uncertainty over validation + OOD samples.
# NOTE(review): the variable name says "98" but the percentile used here is 95.
percentile_98_l_uncertainty_a = np.percentile(combined_val_aleatoric_uncertainties_l, 95)
print(percentile_98_l_uncertainty_a)
plot_tsne_with_uncertainty(
    combined_latent_representations_l,
    combined_val_aleatoric_uncertainties_l,
    uncertainty_type='val_aleatoric_uncertainties',
    threshold=percentile_98_l_uncertainty_a,
)

In [None]:
# Validation-only view of aleatoric uncertainty. The threshold has to be the
# aleatoric one (`..._uncertainty_a`); the epistemic threshold was passed here
# before, which looks like a copy-paste slip from the epistemic cell.
plot_tsne_with_uncertainty(
    latent_representations_eval_l,
    val_aleatoric_uncertainties_l,
    uncertainty_type='val_aleatoric_uncertainties without ood',
    threshold=percentile_98_l_uncertainty_a,
)

In [None]:
# OOD-detection metrics for the locally trained model. The printed label used
# to read "(FEAE)", copied from the federated section, which made the local and
# federated result lines indistinguishable in the output.
# NOTE(review): the three scores use different percentiles (95 / 94.5 / 75),
# while the federated section uses 95 throughout — confirm this is intended,
# since it affects the comparability of the F1 scores.
f1_l, auroc_l, aupr_l, detection_error_l, threshold_l = calculate_ood_metrics(
    recon_error_l,
    recon_ood_error_l,
    threshold_method='percentile',
    percentile=95,
)
print(f"Reconstruction Error F1 score (local EAE): {f1_l:.4f}, AUROC: {auroc_l}, AUPR: {aupr_l}, Detection Error: {detection_error_l}, Threshold: {threshold_l:.4f}")

f1_l_eu, auroc_l_eu, aupr_l_eu, detection_error_l_eu, threshold_l_eu = calculate_ood_metrics(
    val_epistemic_uncertainties_l,
    val_ood_epistemic_uncertainties_l,
    threshold_method='percentile',
    percentile=94.5,
)
print(f"Epistemic Uncertainty F1 score (local EAE): {f1_l_eu:.4f}, AUROC: {auroc_l_eu}, AUPR: {aupr_l_eu}, Detection Error: {detection_error_l_eu}, Threshold: {threshold_l_eu:.4f}")

f1_l_au, auroc_l_au, aupr_l_au, detection_error_l_au, threshold_l_au = calculate_ood_metrics(
    val_aleatoric_uncertainties_l,
    val_ood_aleatoric_uncertainties_l,
    threshold_method='percentile',
    percentile=75,
)
print(f"Aleatoric Uncertainty F1 score (local EAE): {f1_l_au:.4f}, AUROC: {auroc_l_au}, AUPR: {aupr_l_au}, Detection Error: {detection_error_l_au}, Threshold: {threshold_l_au:.4f}")

## Sequential

In [None]:
# Network instance handed to evaluate_saved_model together with the sequential
# checkpoint path; same hyperparameters as the other two models in this notebook.
sequential_model = EvidentialTransformerDenoiseAutoEncoder(
    # data-dependent sizes
    input_dim=input_dim,
    max_seq_length=960,
    # transformer width / attention heads
    d_model=8,
    nhead=4,
    dim_feedforward=32,
    # depth
    num_encoder_layers=2,
    num_decoder_layers=2,
    # regularisation
    dropout_rate=0.1,
)

In [None]:
# Evaluate the sequential checkpoint on the validation loader.
# NOTE(review): lambda_reg is 0.1 here, while the checkpoint name contains
# "lambda05" and the other two sections use 0.5 — confirm which is intended.
(
    val_loss_s,
    val_aleatoric_uncertainties_s,
    val_epistemic_uncertainties_s,
    avg_aleatoric_uncertainty_s,
    avg_epistemic_uncertainty_s,
    latent_representations_eval_s,
    recon_error_s,
) = evaluate_saved_model(
    model_class=sequential_model,
    model_path=sequential_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_dataloader_traj,
    lambda_reg=0.1,
    offset=2.5,
    device='cuda',
    return_latent=True,
)


In [None]:
# Evaluate the same sequential checkpoint on the OOD loader.
# NOTE(review): lambda_reg is 0.1 here, while the checkpoint name contains
# "lambda05" and the other two sections use 0.5 — confirm which is intended.
(
    val_ood_loss_s,
    val_ood_aleatoric_uncertainties_s,
    val_ood_epistemic_uncertainties_s,
    avg_ood_aleatoric_uncertainty_s,
    avg_ood_epistemic_uncertainty_s,
    latent_ood_representations_eval_s,
    recon_ood_error_s,
) = evaluate_saved_model(
    model_class=sequential_model,
    model_path=sequential_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_ood_dataloader_traj,
    lambda_reg=0.1,
    offset=2.5,
    device='cuda',
    return_latent=True,
)

In [None]:
# Stack validation results first and OOD results second, so that entry i of
# every combined container refers to the same sample.
combined_latent_representations_s = np.concatenate([latent_representations_eval_s, latent_ood_representations_eval_s], axis=0)
combined_val_epistemic_uncertainties_s = np.concatenate([val_epistemic_uncertainties_s, val_ood_epistemic_uncertainties_s], axis=0)
combined_val_aleatoric_uncertainties_s = np.concatenate([val_aleatoric_uncertainties_s, val_ood_aleatoric_uncertainties_s], axis=0)
# Convert to lists before `+`: on NumPy arrays `+` adds element-wise (or fails
# on a shape mismatch) instead of concatenating.
combined_recon_error_s = list(recon_error_s) + list(recon_ood_error_s)

In [None]:
# 0 for every validation sample, followed by 1 for every OOD sample.
ood_labels = [0] * len(latent_representations_eval_s)
ood_labels += [1] * len(latent_ood_representations_eval_s)
plot_tsne_with_uncertainty(
    combined_latent_representations_s,
    ood_labels,
    uncertainty_type='ood label',
)

In [None]:
# Threshold = 98th percentile of the reconstruction errors over validation + OOD samples.
percentile_98_s = np.percentile(combined_recon_error_s, 98)
print(percentile_98_s)
plot_tsne_with_uncertainty(
    combined_latent_representations_s,
    combined_recon_error_s,
    uncertainty_type='recon_error',
    threshold=percentile_98_s,
)

In [None]:
# Threshold = 98th percentile of epistemic uncertainty over validation + OOD samples.
percentile_98_s_uncertainty = np.percentile(combined_val_epistemic_uncertainties_s, 98)
print(percentile_98_s_uncertainty)
plot_tsne_with_uncertainty(
    combined_latent_representations_s,
    combined_val_epistemic_uncertainties_s,
    uncertainty_type='val_epistemic_uncertainties',
    threshold=percentile_98_s_uncertainty,
)

In [None]:
# Threshold = 98th percentile of aleatoric uncertainty over validation + OOD samples.
percentile_98_s_uncertainty_a = np.percentile(combined_val_aleatoric_uncertainties_s, 98)
print(percentile_98_s_uncertainty_a)
plot_tsne_with_uncertainty(
    combined_latent_representations_s,
    combined_val_aleatoric_uncertainties_s,
    uncertainty_type='val_aleatoric_uncertainties',
    threshold=percentile_98_s_uncertainty_a,
)