In [7]:
import warnings
from argparse import ArgumentParser

import pandas as pd
import pickle

from datasets import Priv_NAMES as DATASET_NAMES
from datasets import get_private_dataset
from models import get_all_models
from models import get_model
from utils.Server import train
from utils.Toolbox_analysis import create_latent_df, process_latent_df

warnings.simplefilter(action='ignore', category=FutureWarning)


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so such a flag could only be
    disabled by passing an empty string.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept as "false": it was the only falsy input before.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def _parse_domains(value):
    """Parse ``--domains`` from a JSON object string such as '{"Graeme": 5}'.

    ``type=dict`` cannot be used with argparse because ``dict('...')`` raises
    for any string given on the command line.

    Raises:
        argparse.ArgumentTypeError: if the value is not a JSON object mapping
            names to integer participant counts.
    """
    import json
    from argparse import ArgumentTypeError

    if isinstance(value, dict):
        return value
    try:
        domains = json.loads(value)
    except json.JSONDecodeError as err:
        raise ArgumentTypeError(f'--domains must be a JSON object, got {value!r}') from err
    if not isinstance(domains, dict):
        raise ArgumentTypeError(f'--domains must be a JSON object, got {value!r}')
    for name, count in domains.items():
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(count, bool) or not isinstance(count, int):
            raise ArgumentTypeError(
                f'--domains[{name!r}] must be an integer participant count, got {count!r}')
    return domains


def parse_args():
    """Build the experiment argument parser and return the parsed namespace.

    Unrecognised command-line arguments are ignored (``parse_known_args``)
    instead of raising. When ``--parti_num`` is not given it is set to the sum
    of the participant counts in ``--domains``.

    Returns:
        argparse.Namespace: the parsed experiment configuration.
    """
    parser = ArgumentParser(description='You Only Need Me', allow_abbrev=False)
    parser.add_argument('--device_id', type=int, default=0, help='The Device Id for Experiment')
    parser.add_argument('--experiment_id', type=str, default='Pipeline_Full_medium_E', help='Experiment identifier')
    parser.add_argument('--extra_coments', type=str, default='proto_month', help='Aditional info')
    parser.add_argument('--run_simulation', type=_str2bool, default=False, help='Run the simulation step')
    parser.add_argument('--detect_anomalies', type=_str2bool, default=False)
    parser.add_argument('--generate_viz', type=_str2bool, default=True, help='Creates and saves interactive visualizations')


    # Communication - epochs
    parser.add_argument('--communication_epoch', type=int, default=15,
                        help='The Communication Epoch in Federated Learning')
    parser.add_argument('--local_epoch', type=int, default=1, help='The Local Epoch for each Participant')

    # Participants info
    parser.add_argument('--parti_num', type=int, default=None, help='The Number for Participants. If "None" will be setted as the sum of values described in --domain')
    parser.add_argument('--online_ratio', type=float, default=1, help='The Ratio for Online Clients')
    parser.add_argument('--tgt_district', type=str, default='District_E', help='Target district name.')

    # Data parameter
    parser.add_argument('--dataset', type=str, default='fl_leaks', choices=DATASET_NAMES, help='Which scenario to perform experiments on.')
    # The dict default is used as-is; argparse only applies `type` to string values.
    parser.add_argument('--domains', type=_parse_domains, default={
                                                        'Graeme': 5,
                                                        # 'Balerma': 3,
                                                        },
                        help='Domains and respective number of participants.')

    ## Time series preprocessing
    parser.add_argument('--interval_agg', type=int, default=2 * 60 ** 2,
                        help='Agregation interval (seconds) of time series')
    parser.add_argument('--window_size', type=int, default=84, help='Rolling window length')

    # Model (AER) parameters
    parser.add_argument('--input_size', type=int, default=5, help='Number of sensors')  # TODO: adapt
    parser.add_argument('--output_size', type=int, default=5, help='Shape output - dense layer')
    parser.add_argument('--lstm_units', type=int, default=20,
                        help='Number of LSTM units (the latent space will have dimension 2 times bigger')


    # Federated parameters
    parser.add_argument('--model', type=str, default='fpl', help='Federated Model name.', choices=get_all_models()) #fedavg

    parser.add_argument('--structure', type=str, default='homogeneity')

    parser.add_argument('--pri_aug', type=str, default='weak',  # weak strong
                        help='Augmentation for Private Data')
    parser.add_argument('--learning_decay', type=_str2bool, default=False, help='The Option for Learning Rate Decay')
    parser.add_argument('--averaging', type=str, default='weight', help='The Option for averaging strategy')

    parser.add_argument('--infoNCET', type=float, default=0.02, help='The InfoNCE temperature')
    parser.add_argument('--T', type=float, default=0.05, help='The Knowledge distillation temperature')
    parser.add_argument('--weight', type=int, default=1, help='The Weigth for the distillation loss')

    # Unknown arguments are tolerated on purpose; only the known ones are kept.
    args, unknown = parser.parse_known_args()

    if args.parti_num is None:
        args.parti_num = sum(args.domains.values())

    return args

In [8]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
from datasets.utils import FederatedDataset
from models.utils.federated_model import FederatedModel
from utils.timeseries_detection import find_anomalies

def local_evaluate(nets: list,
                   train_dl: list[dict],
                   private_dataset: FederatedDataset,
                   group_detections: bool = False,
                   detect_anomalies: bool = True) -> list:
    """Run every client network on its own data and collect the latent outputs.

    For each ``(net, dl)`` pair the four values returned by ``net.predict`` are
    written back into ``dl`` under ``'ry_hat'``, ``'y_hat'``, ``'fy_hat'`` and
    ``'x_lat'`` (the dicts in ``train_dl`` are modified in place). When
    ``detect_anomalies`` is true, ``dl['errors']`` is also filled and each
    error series is passed to ``find_anomalies``; any non-empty detections are
    stacked into a ``start``/``end``/``severity`` DataFrame, optionally grouped,
    and handed to ``process_anomalies`` together with the client's labels.

    Args:
        nets: one network per client.
        train_dl: one data dict per client, paired with ``nets`` by position.
        private_dataset: provides ``get_labels()``, ``MAP_CLIENTS`` and ``scenario``.
        group_detections: if true, pass the detections through ``group_anomalies``.
        detect_anomalies: if false, skip everything after the prediction step.

    Returns:
        list: ``dl['x_lat']`` of every client, in client order.
    """
    labels_map = private_dataset.get_labels()
    map_clients = private_dataset.MAP_CLIENTS
    latents = []

    for client_idx, (model, batch) in enumerate(zip(nets, train_dl)):
        (batch['ry_hat'],
         batch['y_hat'],
         batch['fy_hat'],
         batch['x_lat']) = model.predict(batch['X'])

        if detect_anomalies:
            batch['errors'] = model.compute_errors(
                batch['X'], batch['ry_hat'], batch['y_hat'], batch['fy_hat']
            )

            # One detection pass per error series; empty results are dropped.
            per_series = (
                find_anomalies(
                    errors=series,
                    index=batch['index'],
                    window_size_portion=0.35,
                    window_step_size_portion=0.10,
                    fixed_threshold=True,
                    inverse=True,
                    anomaly_padding=50,
                    lower_threshold=True
                )
                for series in batch['errors']
            )
            found = [hits for hits in per_series if len(hits)]

            if found:
                anomalies_df = pd.DataFrame(np.vstack(found),
                                            columns=['start', 'end', 'severity'])
                if group_detections:
                    anomalies_df = group_anomalies(anomalies_df)

                client_labels = labels_map[map_clients[client_idx]][private_dataset.scenario]
                process_anomalies(batch, anomalies_df, client_labels)

        # Collected last so it reflects the dict state after any processing above.
        latents.append(batch['x_lat'])

    return latents


def process_latent_df(df_latent, umap_neighbors=50, umap_min_dist=0.95, reduce_raw = False):
    """Add calendar columns to a latent table and compute its 2-D projections.

    The input frame is not modified; a new frame with ``month`` and ``hour``
    columns (derived from ``timestamp``) placed right after ``timestamp`` is
    returned. Columns whose name starts with ``x_`` are treated as the latent
    features; they are rescaled with ``MinMaxScaler`` and passed to
    ``reduce_dims``, whose two results are wrapped as PCA and UMAP frames.

    Args:
        df_latent: DataFrame with a datetime ``timestamp`` column and ``x_*``
            feature columns.
        umap_neighbors: forwarded to ``reduce_dims``.
        umap_min_dist: forwarded to ``reduce_dims``.
        reduce_raw: if true, also project the unscaled features.

    Returns:
        tuple: always three items, but the shape depends on ``reduce_raw``:
            * ``False``: ``(df_latent, df_pca_scaled, df_umap_scaled)``
            * ``True``: ``(df_latent, (df_pca_raw, df_umap_raw),
              (df_pca_scaled, df_umap_scaled))``
    """
    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    timestamps = df_latent['timestamp']
    df_latent = df_latent.assign(month=timestamps.dt.month, hour=timestamps.dt.hour)

    # Move 'month' and 'hour' next to 'timestamp'. The position is computed
    # after removing them so it stays correct even if they already existed.
    other_cols = [col for col in df_latent.columns if col not in ('month', 'hour')]
    split_at = other_cols.index('timestamp') + 1
    df_latent = df_latent[other_cols[:split_at] + ['month', 'hour'] + other_cols[split_at:]]

    # Get feature columns (assumed to be latent vectors)
    feature_cols = [col for col in df_latent.columns if col.startswith('x_')]

    # Original (unscaled) features
    X_raw = df_latent[feature_cols].values

    # Scaled features
    scaler = MinMaxScaler()
    X_scaled = scaler.fit_transform(X_raw)

    X_pca_scaled, X_umap_scaled = reduce_dims(
        X=X_scaled,
        method=None,
        n_components=2,
        umap_neighbors=umap_neighbors,
        umap_min_dist=umap_min_dist
    )

    # The projection frames get a fresh default index, not df_latent's index.
    df_pca_scaled = pd.DataFrame(X_pca_scaled, columns=['pca_0_scaled', 'pca_1_scaled'])
    df_umap_scaled = pd.DataFrame(X_umap_scaled, columns=['umap_0_scaled', 'umap_1_scaled'])

    if reduce_raw:
        # Same reduction on the unscaled features
        X_pca_raw, X_umap_raw = reduce_dims(
            X=X_raw,
            method=None,
            n_components=2,
            umap_neighbors=umap_neighbors,
            umap_min_dist=umap_min_dist
        )
        df_pca_raw = pd.DataFrame(X_pca_raw, columns=['pca_0_raw', 'pca_1_raw'])
        df_umap_raw = pd.DataFrame(X_umap_raw, columns=['umap_0_raw', 'umap_1_raw'])

        return df_latent, (df_pca_raw, df_umap_raw), (df_pca_scaled, df_umap_scaled)

    return df_latent, df_pca_scaled, df_umap_scaled

In [9]:
# Parse the run configuration and derive the pickle paths for this experiment.
args = parse_args()

# Suffix appended to the free-text comment so it becomes part of the file names.
args.extra_coments += "_0.2_20_history"

# interval_agg is given in seconds; the identifier uses whole hours.
agg_int = int(args.interval_agg / 3600)
results_id = f'{args.experiment_id}_{args.communication_epoch}_{args.local_epoch}_{agg_int}_{args.window_size}_{args.extra_coments}'

results_path = f"results/results_{results_id}.pkl"
latent_path = f"results/latent_{results_id}.pkl"
# Bare expression: displays the resolved path as the cell output.
results_path

'results/results_Pipeline_Full_medium_E_15_1_2_84_proto_month_0.2_20_history.pkl'

In [13]:
# Restore the stored experiment results and latent tables.
# pickle can execute code while loading: only open files produced by this project.
with open(results_path, mode='rb') as results_file:
    results = pickle.load(results_file)

with open(latent_path, mode='rb') as latent_file:
    latent_dfs = pickle.load(latent_file)

In [15]:
# Show which entries were stored for the 'Baseline' scenario.
results['Baseline'].keys()

dict_keys(['lat', 'model'])

In [17]:
# Build the dataset object selected by the parsed arguments.
priv_dataset = get_private_dataset(args)

# Request the backbone networks for args.parti_num participants
# (presumably one network per participant — confirm against get_backbone).
backbones_list = priv_dataset.get_backbone(
    parti_num=args.parti_num,
    names_list=None,
    n_series=args.input_size
)

In [19]:
# fed_w = results['Baseline']['model'].nets_list[0].state_dict()
# backbones_list[0].load_state_dict(fed_w)

In [42]:
# train_DL = priv_dataset.get_data_loaders()
# global_model_history = [5 * [history] for history in results['Baseline']['model'].weight_history]
# # nets = results['Baseline']['model'].nets_list

# cases = {}
# lat = local_evaluate(nets, train_DL, priv_dataset, False, False)

In [51]:
# Copy the weights of the first stored Baseline network into the first backbone.
fed_w = results['Baseline']['model'].nets_list[0].state_dict()
backbones_list[0].load_state_dict(fed_w)

<All keys matched successfully>

In [83]:
# Recompute every client's latent vectors for each communication epoch, using
# entry `epoch` of the stored weight history as the weights of all backbones.
train_DL = priv_dataset.get_data_loaders()
global_model_history = results['Baseline']['model'].weight_history

label_clients = [
    'District_A', 'District_B', 'District_C', 'District_D', 'District_E',
    'District_2A', 'District_2B', 'District_2C'
]

# NOTE(review): every client's frame is built with the first loader's index —
# assumes all loaders share the same X_index; confirm.
base_index = train_DL[0]['X_index']
latent_dfs_local = {}

scenario = 'Baseline'
latent_dfs_local[scenario] = {}
aux_latents = []

# `epoch` is bound by the loop; after it finishes it holds the last epoch.
for epoch in range(args.communication_epoch):
    state_dict = global_model_history[epoch]
    for net in backbones_list:
        net.load_state_dict(state_dict)

    # Only the latent vectors are needed here, so anomaly detection is skipped.
    lat = local_evaluate(
        backbones_list,
        train_DL,
        priv_dataset,
        group_detections=False,
        detect_anomalies=False,
    )
    aux_latents.extend(
        create_latent_df(
            X_index=base_index,
            x_lat=client_lat,
            label=f"{scenario}__{label_clients[i]}__{epoch}",
            is_unix=True
        )
        for i, client_lat in enumerate(lat)
    )

# One long table: all clients of epoch 0, then all clients of epoch 1, ...
df_latent = pd.concat(aux_latents)
df_latent

Unnamed: 0,timestamp,label,x_0,x_1,x_2,x_3,x_4,x_5,x_6,x_7,...,x_30,x_31,x_32,x_33,x_34,x_35,x_36,x_37,x_38,x_39
0,1970-01-01 00:00:00,Baseline__District_A__0,-0.010584,0.024031,-0.054760,0.012503,-0.116096,0.186290,0.164527,0.171211,...,0.145238,0.160193,0.206590,-0.191187,-0.187953,0.193912,0.020373,0.176392,-0.012040,0.254001
1,1970-01-01 02:00:00,Baseline__District_A__0,0.007210,0.035698,-0.059576,-0.023484,-0.110416,0.181035,0.146886,0.172546,...,0.134260,0.158613,0.212083,-0.200150,-0.170326,0.199777,0.018015,0.176560,-0.008561,0.247585
2,1970-01-01 04:00:00,Baseline__District_A__0,-0.002227,0.064712,-0.066821,-0.007583,-0.116849,0.205037,0.152792,0.192280,...,0.128884,0.161716,0.211263,-0.197378,-0.179876,0.187486,0.024399,0.178399,0.007575,0.236477
3,1970-01-01 06:00:00,Baseline__District_A__0,-0.005221,0.068630,-0.068413,-0.009454,-0.116762,0.203665,0.141542,0.179634,...,0.130580,0.149984,0.222165,-0.193982,-0.193121,0.198939,0.036084,0.175410,-0.000019,0.270071
4,1970-01-01 08:00:00,Baseline__District_A__0,-0.015873,0.052474,-0.063953,0.009202,-0.123932,0.205712,0.147815,0.167174,...,0.135158,0.151144,0.231073,-0.190794,-0.271501,0.206593,0.041648,0.152544,0.002593,0.300308
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4232,1970-12-19 16:00:00,Baseline__District_E__14,0.083077,-0.178612,0.238339,0.241603,0.479855,-0.077359,0.150663,0.174203,...,0.038450,0.156507,0.166817,0.207140,0.030295,-0.037974,-0.488502,-0.160213,0.176935,0.217187
4233,1970-12-19 18:00:00,Baseline__District_E__14,-0.110276,0.246206,0.538833,0.340932,0.049162,0.005554,0.011533,0.182439,...,0.160744,0.311915,0.139415,0.183717,-0.069331,-0.131696,-0.445472,-0.234817,0.162345,0.216505
4234,1970-12-19 20:00:00,Baseline__District_E__14,0.125138,-0.144838,0.315174,0.241153,0.318795,-0.196034,0.113309,0.179579,...,0.018326,0.356105,0.186163,0.242739,-0.157743,-0.137031,-0.311968,-0.136737,0.288261,0.237323
4235,1970-12-19 22:00:00,Baseline__District_E__14,0.049569,-0.041379,0.007257,0.251374,0.280496,0.187211,0.053751,0.211509,...,-0.192353,0.119845,0.218454,0.239805,0.049464,-0.061338,-0.351593,-0.138016,0.334469,0.227859


In [99]:
# Rebuild one table from the cached Baseline entries, without the derived
# 'hour' and 'week' columns.
final_qwe = pd.concat([
    cached['latent_space'].drop(columns=['hour', 'week'])
    for cached in latent_dfs['Baseline'].values()
])
final_qwe

Unnamed: 0,timestamp,label,x_0,x_1,x_2,x_3,x_4,x_5,x_6,x_7,...,x_30,x_31,x_32,x_33,x_34,x_35,x_36,x_37,x_38,x_39
0,1970-01-01 00:00:00,Baseline__District_A__0,-0.010584,0.024031,-0.054760,0.012503,-0.116096,0.186290,0.164527,0.171211,...,0.145238,0.160193,0.206590,-0.191187,-0.187953,0.193912,0.020373,0.176392,-0.012040,0.254001
1,1970-01-01 02:00:00,Baseline__District_A__0,0.007210,0.035698,-0.059576,-0.023484,-0.110416,0.181035,0.146886,0.172546,...,0.134260,0.158613,0.212083,-0.200150,-0.170326,0.199777,0.018015,0.176560,-0.008561,0.247585
2,1970-01-01 04:00:00,Baseline__District_A__0,-0.002227,0.064712,-0.066821,-0.007583,-0.116849,0.205037,0.152792,0.192280,...,0.128884,0.161716,0.211263,-0.197378,-0.179876,0.187486,0.024399,0.178399,0.007575,0.236477
3,1970-01-01 06:00:00,Baseline__District_A__0,-0.005221,0.068630,-0.068413,-0.009454,-0.116762,0.203665,0.141542,0.179634,...,0.130580,0.149984,0.222165,-0.193982,-0.193121,0.198939,0.036084,0.175410,-0.000019,0.270071
4,1970-01-01 08:00:00,Baseline__District_A__0,-0.015873,0.052474,-0.063953,0.009202,-0.123932,0.205712,0.147815,0.167174,...,0.135158,0.151144,0.231073,-0.190794,-0.271501,0.206593,0.041648,0.152544,0.002593,0.300308
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4232,1970-12-19 16:00:00,Baseline__District_E__14,0.083077,-0.178612,0.238339,0.241603,0.479855,-0.077359,0.150663,0.174203,...,0.038450,0.156507,0.166817,0.207140,0.030295,-0.037974,-0.488502,-0.160213,0.176935,0.217187
4233,1970-12-19 18:00:00,Baseline__District_E__14,-0.110276,0.246206,0.538833,0.340932,0.049162,0.005554,0.011533,0.182439,...,0.160744,0.311915,0.139415,0.183717,-0.069331,-0.131696,-0.445472,-0.234817,0.162345,0.216505
4234,1970-12-19 20:00:00,Baseline__District_E__14,0.125138,-0.144838,0.315174,0.241153,0.318795,-0.196034,0.113309,0.179579,...,0.018326,0.356105,0.186163,0.242739,-0.157743,-0.137031,-0.311968,-0.136737,0.288261,0.237323
4235,1970-12-19 22:00:00,Baseline__District_E__14,0.049569,-0.041379,0.007257,0.251374,0.280496,0.187211,0.053751,0.211509,...,-0.192353,0.119845,0.218454,0.239805,0.049464,-0.061338,-0.351593,-0.138016,0.334469,0.227859


In [103]:
# Sanity check: the recomputed table (df_latent) must match the one rebuilt
# from the cached pickle (final_qwe) exactly.
df_latent.equals(final_qwe)

True

In [None]:
# process_latent_df always returns three items; the raw projections are only
# computed with reduce_raw=True, in which case items two and three are
# (pca, umap) pairs that have to be unpacked as nested tuples.
df_latent, (df_pca_raw, df_umap_raw), (df_pca_scaled, df_umap_scaled) = process_latent_df(
    df_latent,
    umap_neighbors=50,
    umap_min_dist=0.95,
    reduce_raw=True,
)

# NOTE(review): df_latent holds the rows of every epoch, yet it is stored under
# the single key left in `epoch` by the preceding loop — confirm this is intended.
latent_dfs[scenario][epoch] = {
    'latent_space': df_latent,
    'pca_raw': df_pca_raw,
    'pca_scl': df_pca_scaled,
    'umap_raw': df_umap_raw,
    'umap_scl': df_umap_scaled
}

array([ 0.74661344,  0.09786752, -0.43723458, -0.13199764, -0.24898678,
        0.12381948,  0.6496835 ,  0.43976417, -0.12445526, -0.5409774 ,
        0.47943497,  0.5415784 ,  0.37453976, -0.0360894 ,  0.63657755,
        0.40045112, -0.2876662 , -0.48390603,  0.2058657 ,  0.33092305,
        0.72982144,  0.19135489, -0.39418715, -0.12652794, -0.25964922,
        0.13351041,  0.6474296 ,  0.5326093 , -0.11065613, -0.5486495 ,
        0.47979334,  0.5237646 ,  0.37276983,  0.08710243,  0.6568862 ,
        0.38777098, -0.35343802, -0.5196516 ,  0.19210084,  0.40662995],
      dtype=float32)

In [194]:
# Display the cached latent table stored under key 14 of the Baseline scenario.
latent_dfs['Baseline'][14]['latent_space']

Unnamed: 0,timestamp,week,hour,label,x_0,x_1,x_2,x_3,x_4,x_5,...,x_30,x_31,x_32,x_33,x_34,x_35,x_36,x_37,x_38,x_39
0,1970-01-01 00:00:00,1969-12-29/1970-01-04,0,Baseline__District_A__14,-0.142941,-0.001657,-0.214057,-0.369423,-0.343319,0.041360,...,0.526157,0.549127,0.471029,-0.019304,0.666300,0.398162,-0.275491,-0.463293,0.191563,0.301749
1,1970-01-01 02:00:00,1969-12-29/1970-01-04,2,Baseline__District_A__14,0.746613,0.097868,-0.437235,-0.131998,-0.248987,0.123819,...,0.479793,0.523765,0.372770,0.087102,0.656886,0.387771,-0.353438,-0.519652,0.192101,0.406630
2,1970-01-01 04:00:00,1969-12-29/1970-01-04,4,Baseline__District_A__14,0.303491,0.112474,-0.417665,-0.223863,-0.263058,0.050022,...,0.563074,0.549353,0.550372,-0.046125,0.677228,0.396480,-0.173387,-0.431698,0.229754,0.227236
3,1970-01-01 06:00:00,1969-12-29/1970-01-04,6,Baseline__District_A__14,0.205533,0.028335,-0.469488,-0.202246,-0.256944,0.030829,...,0.466656,0.542350,0.360741,-0.001037,0.640295,0.380910,-0.259244,-0.503103,0.231032,0.325150
4,1970-01-01 08:00:00,1969-12-29/1970-01-04,8,Baseline__District_A__14,0.570260,0.233201,-0.305453,-0.124365,-0.272481,0.136685,...,0.513575,0.515573,0.365029,0.281146,0.660709,0.310505,-0.319252,-0.489226,0.170169,0.341535
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4232,1970-12-19 16:00:00,1970-12-14/1970-12-20,16,Baseline__District_E__14,0.068341,-0.023849,-0.187381,-0.177456,0.493320,0.042576,...,-0.165605,-0.110583,-0.140423,0.042948,-0.236380,0.151115,0.054370,0.118439,-0.078135,0.245746
4233,1970-12-19 18:00:00,1970-12-14/1970-12-20,18,Baseline__District_E__14,0.077923,0.043670,-0.080081,-0.103544,0.496764,0.177605,...,-0.177590,-0.053128,-0.116437,0.030316,-0.179021,0.134711,0.106423,0.020251,-0.154162,0.097579
4234,1970-12-19 20:00:00,1970-12-14/1970-12-20,20,Baseline__District_E__14,0.088896,-0.036994,-0.064206,-0.118546,0.495774,0.127977,...,-0.189935,-0.051192,-0.171383,0.101887,-0.129168,0.062005,0.043601,0.168497,-0.195081,0.111074
4235,1970-12-19 22:00:00,1970-12-14/1970-12-20,22,Baseline__District_E__14,0.081692,0.037977,-0.057168,-0.040005,0.428567,0.033290,...,-0.168693,-0.169162,-0.112970,0.104904,-0.416869,0.215599,-0.079609,0.333224,-0.037302,0.213235
