In [7]:
import warnings
from argparse import ArgumentParser
import pandas as pd
import pickle

import subprocess
import sys
import os

from datasets import Priv_NAMES as DATASET_NAMES
from datasets import get_private_dataset
from models import get_all_models, get_model
from utils.Server import train
from utils.Toolbox_analysis import create_latent_df, process_latent_df
from utils.Toolbox_visualization import format_latent_dict, load_and_scale_data, combine_latents, plot_latent_heatmap, plot_time_series_and_latents
warnings.simplefilter(action='ignore', category=FutureWarning)

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: the converter receives a string and
    ``bool('False')`` is ``True`` (every non-empty string is truthy), so such a flag
    can never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    # Local import: only ArgumentParser is imported at file level.
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def _parse_domains(value):
    """Parse ``--domains`` given as a JSON object, e.g. ``'{"Graeme": 5}'``.

    ``type=dict`` cannot work here because ``dict('<string>')`` always raises, so the
    option could never be supplied on the command line.

    Raises:
        argparse.ArgumentTypeError: if the value is not a JSON object whose values
            are integers.
    """
    import json
    from argparse import ArgumentTypeError

    if isinstance(value, dict):
        return value
    try:
        domains = json.loads(value)
    except (TypeError, ValueError) as exc:
        raise ArgumentTypeError(f'--domains must be a JSON object, got {value!r}') from exc
    counts_ok = isinstance(domains, dict) and all(
        isinstance(count, int) and not isinstance(count, bool) for count in domains.values()
    )
    if not counts_ok:
        raise ArgumentTypeError(
            f'--domains must map domain names to integer participant counts, got {value!r}'
        )
    return domains


def parse_args():
    """Build the experiment configuration from the command line.

    Unknown arguments are ignored (``parse_known_args``), so the function can be
    called from inside a notebook kernel whose ``sys.argv`` carries unrelated options.

    Returns:
        argparse.Namespace: the parsed options. When ``--parti_num`` is not given,
        ``parti_num`` is set to the sum of the participant counts in ``domains``.
    """
    parser = ArgumentParser(description='You Only Need Me', allow_abbrev=False)
    parser.add_argument('--device_id', type=int, default=0, help='The Device Id for Experiment')
    # Boolean flags use _str2bool so that "--flag False" really yields False.
    parser.add_argument('--run_simulation', type=_str2bool, default=True, help='Whether to run the simulation')
    parser.add_argument('--detect_anomalies', type=_str2bool, default=False)
    parser.add_argument('--generate_viz', type=_str2bool, default=True, help='Creates and saves interactive visualizations')


    # Communication - epochs
    parser.add_argument('--communication_epoch', type=int, default=5,
                        help='The Communication Epoch in Federated Learning')
    parser.add_argument('--local_epoch', type=int, default=3, help='The Local Epoch for each Participant')

    # Participants info
    parser.add_argument('--parti_num', type=int, default=None, help='The Number for Participants. If "None" will be setted as the sum of values described in --domain')
    parser.add_argument('--online_ratio', type=float, default=1, help='The Ratio for Online Clients')

    # Data parameter
    parser.add_argument('--dataset', type=str, default='fl_leaks', choices=DATASET_NAMES, help='Which scenario to perform experiments on.')
    parser.add_argument('--experiment_id', type=str, default='Pipeline_Full', help='Experiment identifier')
    parser.add_argument('--extra_coments', type=str, default='proto_month', help='Aditional info')
    parser.add_argument('--domains', type=_parse_domains, default={
                                                        'Graeme': 5,
                                                        # 'Balerma': 3,
                                                        },
                        help='Domains and respective number of participants.')

    ## Time series preprocessing
    parser.add_argument('--interval_agg', type=int, default=2 * 60 ** 2,
                        help='Agregation interval (seconds) of time series')
    parser.add_argument('--window_size', type=int, default=84, help='Rolling window length')

    # Model (AER) parameters
    parser.add_argument('--input_size', type=int, default=5, help='Number of sensors')  # TODO: adapt
    parser.add_argument('--output_size', type=int, default=5, help='Shape output - dense layer')
    parser.add_argument('--lstm_units', type=int, default=30,
                        help='Number of LSTM units (the latent space will have dimension 2 times bigger')


    # Federated parameters
    parser.add_argument('--model', type=str, default='fpl', help='Federated Model name.', choices=get_all_models()) #fedavg

    parser.add_argument('--structure', type=str, default='homogeneity')

    parser.add_argument('--pri_aug', type=str, default='weak',  # weak strong
                        help='Augmentation for Private Data')
    parser.add_argument('--learning_decay', type=_str2bool, default=False, help='The Option for Learning Rate Decay')
    parser.add_argument('--averaging', type=str, default='weight', help='The Option for averaging strategy')

    parser.add_argument('--infoNCET', type=float, default=0.02, help='The InfoNCE temperature')
    parser.add_argument('--T', type=float, default=0.05, help='The Knowledge distillation temperature')
    parser.add_argument('--weight', type=int, default=1, help='The Weigth for the distillation loss')

    # parse_known_args tolerates arguments that do not belong to this parser.
    args, _unknown = parser.parse_known_args()

    if args.parti_num is None:
        args.parti_num = sum(args.domains.values())

    return args

# Parse the experiment configuration once; later cells read it through `args`.
args = parse_args()

In [9]:
# Aggregation interval component of the result file name.
# NOTE(review): hard-coded here rather than derived from args.interval_agg — confirm they agree.
agg_int = 2

# Identifier that the result file names of this experiment configuration are built from.
results_id = '_'.join(
    str(part)
    for part in (
        args.experiment_id,
        args.communication_epoch,
        args.local_epoch,
        agg_int,
        args.window_size,
        args.extra_coments,
    )
)

results_path = "results/results_" + results_id + ".pkl"
latent_path = "results/" + results_id + ".pkl"

# pickle can run arbitrary code while loading: only open files produced by this project.
with open(results_path, mode='rb') as f:
    results = pickle.load(f)

In [29]:
# List the keys under which the baseline model stored its local_history.
results['Baseline']['model'].local_history.keys()

dict_keys([0, 1, 2, 3, 4])

In [42]:
# Length of the first element stored under key 0 of the baseline model's global_history.
len(results['Baseline']['model'].global_history[0][0])

2

In [53]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
import umap
import numpy as np
import pandas as pd


def flatten_prototypes(local_history):
    """Turn the nested prototype history into one flat table.

    ``local_history`` is walked as ``{epoch: [client, client, ...]}`` where every
    client entry is a ``{label: prototype}`` mapping. Each (epoch, client, label)
    triple becomes one row, and the components of its prototype are spread over the
    columns ``feature_0``, ``feature_1``, ... The client id is the position of the
    client inside the epoch's sequence.

    Returns:
        pandas.DataFrame with columns ``epoch``, ``client_id``, ``label`` followed
        by the ``feature_*`` columns.
    """
    rows = [
        {
            'epoch': epoch,
            'client_id': client_id,
            'label': label,
            # One column per prototype component, in component order.
            **{f'feature_{k}': component for k, component in enumerate(np.array(prototype))},
        }
        for epoch, clients in local_history.items()
        for client_id, client_data in enumerate(clients)
        for label, prototype in client_data.items()
    ]
    return pd.DataFrame(rows)


def normalize_by_epoch(df, feature_cols):
    """
    Rescale the feature columns with MinMaxScaler, fitting one scaler per epoch.

    The rows of every ``epoch`` group are scaled independently of the other groups;
    the scaled groups are then concatenated in groupby order with a fresh index.

    Parameters:
    - df: DataFrame holding an ``epoch`` column and the columns in ``feature_cols``
    - feature_cols: names of the columns to rescale

    Returns:
    - df_normalized: copy of the data with the feature columns rescaled
    - scalers_by_epoch: dict of epoch -> fitted MinMaxScaler
    """
    fitted_scalers = {}
    scaled_groups = []

    for epoch_key, epoch_rows in df.groupby("epoch"):
        epoch_scaler = MinMaxScaler()
        # Work on a copy so the caller's frame is left untouched.
        rescaled = epoch_rows.copy()
        rescaled[feature_cols] = epoch_scaler.fit_transform(epoch_rows[feature_cols])

        fitted_scalers[epoch_key] = epoch_scaler
        scaled_groups.append(rescaled)

    return pd.concat(scaled_groups, ignore_index=True), fitted_scalers


def reduce_dims(X, method=None, n_components=2, umap_neighbors=50, umap_min_dist=0.95,
                random_state=None):
    """
    Project the input data with PCA and/or UMAP.

    Parameters:
    - X: 2D array-like of samples x features
    - method: which projection to compute. ``'pca'`` or ``'umap'`` (matched
      case-insensitively, so ``'PCA'`` and ``'pca'`` are equivalent) computes only
      that projection; ``None`` or any other value computes both.
    - n_components: number of output dimensions for both projections
    - umap_neighbors: ``n_neighbors`` passed to UMAP
    - umap_min_dist: ``min_dist`` passed to UMAP
    - random_state: optional seed forwarded to PCA and UMAP so that a run can be
      reproduced; ``None`` keeps the libraries' default behaviour.

    Returns:
    - X_pca: PCA-reduced data, or None when only UMAP was requested
    - X_umap: UMAP-reduced data, or None when only PCA was requested
    """
    selected = method.strip().lower() if isinstance(method, str) else None

    X_pca = None
    X_umap = None

    # Fit only what was asked for: each projection is skipped when the other one
    # is explicitly selected.
    if selected != 'umap':
        pca = PCA(n_components=n_components, random_state=random_state)
        X_pca = pca.fit_transform(X)

    if selected != 'pca':
        reducer = umap.UMAP(
            n_components=n_components,
            n_neighbors=umap_neighbors,
            min_dist=umap_min_dist,
            random_state=random_state,
        )
        X_umap = reducer.fit_transform(X)

    return X_pca, X_umap


def process_prototypes(local_history, method=None):
    """
    Run the whole prototype pipeline on a local history:
    flatten it into a table, MinMax-scale the features per epoch, then reduce the
    scaled features to two dimensions.

    Parameters:
    - local_history: nested prototype history accepted by flatten_prototypes
    - method: forwarded unchanged to reduce_dims

    Returns:
    - df_final: scaled DataFrame, extended with ``pca_0``/``pca_1`` and/or
      ``umap_0``/``umap_1`` for every projection that reduce_dims returned
    - scalers_by_epoch: dict of epoch -> MinMaxScaler
    """
    prototypes = flatten_prototypes(local_history)
    feature_cols = [name for name in prototypes.columns if name.startswith("feature_")]

    scaled, scalers_by_epoch = normalize_by_epoch(prototypes, feature_cols)

    # reduce_dims returns (pca, umap); either entry may be None when it was not computed.
    projections = zip(('pca', 'umap'), reduce_dims(scaled[feature_cols].values, method=method))
    for prefix, coords in projections:
        if coords is not None:
            scaled[[f'{prefix}_0', f'{prefix}_1']] = coords

    return scaled, scalers_by_epoch

In [59]:
# Run the prototype pipeline on the baseline model's local history.
# method=None is forwarded to reduce_dims, which then computes both PCA and UMAP.
local_history = results['Baseline']['model'].local_history
df_final, scalers_by_epoch = process_prototypes(local_history, method=None)
df_final

Unnamed: 0,epoch,client_id,label,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,...,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,pca_0,pca_1,umap_0,umap_1
0,0,0,9,0.798466,0.157108,0.290094,0.413068,0.754379,0.587829,0.138198,...,0.832880,0.364022,0.092465,0.000000,0.000000,0.074447,-0.611373,-0.004831,3.509451,10.060936
1,0,0,10,0.835216,0.145559,0.000000,0.454378,0.669226,0.433798,0.108227,...,1.000000,0.000000,0.000000,0.498345,0.224783,0.284221,-0.743875,0.076317,3.523551,10.046944
2,0,0,2,0.899234,0.126378,0.084828,0.449107,0.865067,0.590768,0.104950,...,0.917740,0.205260,0.093297,0.216481,0.146778,0.160602,-0.679403,-0.011020,3.585328,9.984883
3,0,0,1,0.739930,0.039756,0.132628,0.371167,0.540614,0.756935,0.128312,...,0.952065,0.227770,0.046549,0.334727,0.102875,0.249288,-0.641394,0.219046,3.685874,9.884972
4,0,0,3,0.986366,0.222766,0.269221,0.439561,0.975105,0.517488,0.093716,...,0.795736,0.342447,0.171884,0.025667,0.063047,0.000000,-0.566910,-0.057515,3.571238,9.999067
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,4,4,11,0.451066,0.813788,0.425699,0.969853,0.593466,0.132547,0.131473,...,0.112432,0.958564,0.909109,0.202094,0.251361,0.838744,1.780319,-1.315059,-6.807067,17.664486
296,4,4,10,0.242193,0.789805,0.504517,0.901965,0.429336,0.307541,0.067759,...,0.047195,0.716454,0.981581,0.325262,0.405568,0.936202,1.859067,-1.290245,-6.808478,17.673836
297,4,4,4,0.280435,0.845998,0.494017,0.869640,0.342759,0.179854,0.122747,...,0.073716,0.707630,0.940123,0.373518,0.306244,0.874687,1.873862,-1.361971,-6.816955,17.680876
298,4,4,6,0.477708,0.788367,0.485855,0.907886,0.379549,0.191667,0.110254,...,0.024102,0.773703,0.987158,0.213067,0.212035,0.893856,1.828908,-1.321961,-6.821355,17.684061


In [68]:
import altair as alt

def plot_prototypes(df, method='umap'):
    """
    Build an interactive Altair scatter plot of the reduced prototypes.

    Two dropdowns (epoch and label) filter the points that are shown; colour
    encodes the label and marker shape encodes the client id.

    Parameters:
    - df: DataFrame from process_prototypes
    - method: 'umap' or 'pca' (selects the ``<method>_0`` / ``<method>_1`` columns)

    Returns:
    - the Altair chart (zoom and pan enabled)
    """
    assert method in ['umap', 'pca'], "method must be 'umap' or 'pca'"

    def dropdown_filter(column, caption):
        # Dropdown listing the distinct values of `column`, bound to a point selection.
        widget = alt.binding_select(options=sorted(df[column].unique()), name=caption)
        return alt.selection_point(fields=[column], bind=widget)

    epoch_selection = dropdown_filter('epoch', "Epoch")
    label_selection = dropdown_filter('label', "Label")

    encodings = dict(
        x=alt.X(f'{method}_0', title=f"{method.upper()} 1"),
        y=alt.Y(f'{method}_1', title=f"{method.upper()} 2"),
        color=alt.Color('label:N', title='Label'),
        shape=alt.Shape('client_id:N', title='Client ID'),
        tooltip=['epoch', 'client_id', 'label'],
    )

    scatter = alt.Chart(df).mark_point(filled=True, size=100).encode(**encodings)
    scatter = scatter.add_params(epoch_selection, label_selection)
    for selection in (epoch_selection, label_selection):
        scatter = scatter.transform_filter(selection)
    scatter = scatter.properties(
        width=600,
        height=400,
        title=f'{method.upper()} Projection of Prototypes (Interactive)',
    )

    # interactive() enables zoom and pan.
    return scatter.interactive()

In [86]:
def compute_distance_lines(df, x_col, y_col):
    """
    Compute the pairwise Euclidean distances between points that share the same
    epoch and label.

    Parameters:
    - df: DataFrame with ``epoch`` and ``label`` columns plus the two coordinate columns
    - x_col, y_col: names of the coordinate columns

    Returns:
    - DataFrame with one row per unordered pair of points and the columns
      ``epoch``, ``label``, ``x1``, ``y1``, ``x2``, ``y2``, ``distance``.
      The columns are present even when no group holds two or more points, so
      callers can always access ``distance``.
    """
    # Imported here because the file does not import itertools at module level.
    from itertools import combinations

    columns = ['epoch', 'label', 'x1', 'y1', 'x2', 'y2', 'distance']
    lines = []

    for (epoch, label), group in df.groupby(['epoch', 'label']):
        coords = group[[x_col, y_col]].values

        # Groups with a single point yield no pairs and therefore no rows.
        for start, end in combinations(coords, 2):
            lines.append({
                'epoch': epoch,
                'label': label,
                'x1': start[0], 'y1': start[1],
                'x2': end[0], 'y2': end[1],
                'distance': np.linalg.norm(start - end)
            })

    return pd.DataFrame(lines, columns=columns)
def plot_prototypes_with_distances(df, method='umap'):
    """
    Interactive Altair plot with:
    - Epoch slider
    - Label slider
    - Distance segments only within the selected epoch and label

    Parameters:
    - df: DataFrame from process_prototypes (needs ``epoch``, ``label``,
      ``client_id`` and the ``<method>_0`` / ``<method>_1`` columns)
    - method: 'umap' or 'pca'

    Returns:
    - the Altair chart: the scatter plot layered over the distance segments, or
      the scatter plot alone when no epoch/label group has at least two points
    """
    assert method in ['umap', 'pca'], "method must be 'umap' or 'pca'"
    x_col = f'{method}_0'
    y_col = f'{method}_1'
    x_title = f"{method.upper()} 1"
    y_title = f"{method.upper()} 2"

    # Precompute distances grouped by epoch + label
    df_lines = compute_distance_lines(df, x_col, y_col)

    # Interactive widgets: integer sliders spanning the observed epochs / labels
    epoch_slider = alt.binding_range(min=int(df['epoch'].min()),
                                     max=int(df['epoch'].max()),
                                     step=1, name="Epoch")
    epoch_selection = alt.selection_point(fields=['epoch'], bind=epoch_slider)

    label_selector = alt.binding_range(min=int(df['label'].min()),
                                       max=int(df['label'].max()),
                                       step=1, name="Label")
    label_selection = alt.selection_point(fields=['label'], bind=label_selector)

    # Prototype scatter points
    point_chart = alt.Chart(df).mark_point(filled=True, size=100).encode(
        x=alt.X(x_col, title=x_title),
        y=alt.Y(y_col, title=y_title),
        color=alt.Color('label:N', title='Label'),
        shape=alt.Shape('client_id:N', title='Client ID'),
        tooltip=['epoch', 'client_id', 'label']
    ).transform_filter(
        epoch_selection
    ).transform_filter(
        label_selection
    )

    layers = point_chart
    # Without any pair there is nothing to draw and no 'distance' value to scale by.
    if not df_lines.empty:
        # A rule mark draws one segment from (x, y) to (x2, y2) per row; a line
        # mark would ignore x2/y2 and connect the (x1, y1) points into one polyline.
        line_chart = alt.Chart(df_lines).mark_rule(opacity=0.2).encode(
            # Same axis titles as the scatter layer so the shared axes get one title.
            x=alt.X('x1:Q', title=x_title),
            y=alt.Y('y1:Q', title=y_title),
            x2='x2:Q', y2='y2:Q',
            strokeWidth=alt.StrokeWidth('distance:Q',
                scale=alt.Scale(domain=[0, df_lines['distance'].max()], range=[0.5, 5])),
            tooltip=['distance']
        ).transform_filter(
            epoch_selection
        ).transform_filter(
            label_selection
        )
        layers = line_chart + point_chart

    # Combine
    chart = layers.add_params(
        epoch_selection,
        label_selection
    ).properties(
        width=300,
        height=300,
        title=f'{method.upper()} Projection with Pairwise Distances (Filtered)'
    ).interactive()

    return chart


In [88]:
# Build and render the PCA projection with pairwise-distance lines.
plot = plot_prototypes_with_distances(df_final, method='pca')
plot.show()  # or just `plot` in Jupyter
