# Single Cluster Federated MTL Setup
This notebook contains all setup and configuration for single cluster federated multi-task learning.
Ready for multi-cluster and CH compromisation experiments.

## 1. Imports and Data Loading

In [None]:
import numpy as np
import random
import os
import flwr as fl
import ray
import tensorflow as tf
from tensorflow import keras

tf.keras.backend.clear_session()

import pandas as pd
%matplotlib inline
df = pd.read_csv('datasets/local_cache/dataset_12500_samples_65_features.csv')

# Drop features with high label leakage
cols_to_drop = [
 'fin_flag_cnt', 'syn_flag_cnt', 'rst_flag_cnt', 'psh_flag_cnt',
 'ack_flag_cnt', 'urg_flag_cnt', 'cwe_flag_cnt', 'ece_flag_cnt',
 'fwd_header_length', 'bwd_header_length',
 'active_mean', 'active_std', 'active_max', 'active_min',
 'idle_mean', 'idle_std', 'idle_max', 'idle_min',
 'subflow_fwd_bytes'
]
df = df.drop(columns=[c for c in cols_to_drop if c in df.columns])

print(f" Data loaded: {len(df)} samples, {len(df.columns)} features")

 Data loaded: 12500 samples, 48 features


In [2]:
seed = 42

random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = str(seed)

print(" Random seeds set for reproducibility")

 Random seeds set for reproducibility


## 3. Configuration

In [3]:
CFG = {
    # Training parameters
    'local_epochs': 1,
    'lr': 1e-3,
    'loss_weights': {'traffic': 1, 'duration': 1, 'bandwidth': 1},
    'test_size': 0.2,
    
    # Client configuration
    'n_clients_flat': 600,
    'n_clusters': 3,
    'clients_per_cluster': 200,
    'client_frac': 1.0,  # 100% client participation
    
    # Hierarchical FL
    'global_aggregator_cluster': 1,  # Cluster 1 performs global aggregation
    
    # Data distribution (TWO-LEVEL SPLIT)
    'cluster_split': 'equal',      # How to split data among clusters ('equal' or 'dirichlet')
    'client_split': 'dirichlet',   # How to split data among clients within clusters (always 'dirichlet')
    'alpha_client': 0.4,           # Dirichlet α for client-level distribution
    'alpha_cluster': 0.4,          # Dirichlet α for cluster-level distribution (when cluster_split='dirichlet')
}

print("Configuration:")
print(f"  Training: local_epochs={CFG['local_epochs']}, lr={CFG['lr']}")
print(f"  Clients: {CFG['n_clients_flat']} total ({CFG['n_clusters']} clusters × {CFG['clients_per_cluster']} clients)")
print(f"  Participation: {CFG['client_frac']*100}%")
print(f"\n  Data Distribution (Two-Level):")
print(f"    Level 1 (Clusters): {CFG['cluster_split']} split")
print(f"    Level 2 (Clients): {CFG['client_split']} split (α={CFG['alpha_client']})")
print(f"    Cluster α: {CFG['alpha_cluster']} (used when cluster_split='dirichlet')")

Configuration:
  Training: local_epochs=1, lr=0.001
  Clients: 600 total (3 clusters × 200 clients)
  Participation: 100.0%

  Data Distribution (Two-Level):
    Level 1 (Clusters): equal split
    Level 2 (Clients): dirichlet split (α=0.4)
    Cluster α: 0.4 (used when cluster_split='dirichlet')


## 4. Feature Selection for Each Task

In [4]:
# Define features to exclude for each task (prevent label leakage)
exclude_traffic = [
 'src_ip', 'dst_ip', 'src_port', 'dst_port', # identity → leakage
 'protocol', # not useful for QUIC-only
 'label', 'flow_duration', 'flow_bytes_per_s', 'bandwidth_bps'
]

exclude_duration = [
 'src_ip', 'dst_ip', 'src_port', 'dst_port',
 'protocol',
 'label', 'flow_duration', 'flow_bytes_per_s', 'bandwidth_bps'
]

exclude_bandwidth = [
 'src_ip', 'dst_ip', 'src_port', 'dst_port',
 'protocol',
 'label', 'flow_duration', 'flow_bytes_per_s', 'bandwidth_bps'
]

Xcols_traffic = [col for col in df.columns if col not in exclude_traffic]
Xcols_duration = [col for col in df.columns if col not in exclude_duration]
Xcols_bandwidth = [col for col in df.columns if col not in exclude_bandwidth]

print(f" Traffic features: {len(Xcols_traffic)}")
print(f" Duration features: {len(Xcols_duration)}")
print(f" Bandwidth features: {len(Xcols_bandwidth)}")

 Traffic features: 39
 Duration features: 39
 Bandwidth features: 39


## 5. Train/Test Split

In [5]:
from sklearn.model_selection import train_test_split

n = len(df)
indices = np.arange(n)
train_idx, test_idx = train_test_split(
 indices, 
 test_size=CFG['test_size'], 
 random_state=seed, 
 shuffle=True
)

train_df = df.iloc[train_idx].copy()
test_df = df.iloc[test_idx].copy()

print(f" Train samples: {len(train_df)}")
print(f" Test samples: {len(test_df)}")

 Train samples: 10000
 Test samples: 2500


## 6. Winsorization (Outlier Handling)

In [6]:
from scipy.stats.mstats import winsorize

all_features = list(set(Xcols_traffic + Xcols_duration + Xcols_bandwidth))

# Calculate winsorization bounds from training data
winsor_bounds = {}
for col in all_features:
    if col in train_df.columns:
     lower = train_df[col].quantile(0.01)
     upper = train_df[col].quantile(0.99)
     winsor_bounds[col] = (lower, upper)

# Apply winsorization
for col, (lower, upper) in winsor_bounds.items():
 lower_limit = (train_df[col] < lower).mean()
 upper_limit = (train_df[col] > upper).mean()
 
 for df_temp in [train_df, test_df]:
     df_temp[col] = winsorize(df_temp[col], limits=(lower_limit, upper_limit))

## 7. Target Variable Processing

In [7]:
# Create quantile-based labels for duration and bandwidth (5 classes each)
y_dur_raw_train = train_df['flow_duration'].values
y_bw_raw_train = train_df['bandwidth_bps'].values

# Log-transform
bw_log = np.log1p(y_bw_raw_train)
dur_log = np.log1p(y_dur_raw_train)

# Compute 5-bin quantiles (20%, 40%, 60%, 80%)
bw_quantiles = np.quantile(bw_log, [0.20, 0.40, 0.60, 0.80])
dur_quantiles = np.quantile(dur_log, [0.20, 0.40, 0.60, 0.80])

def create_quantile_labels(raw_values, quantiles):
 """Create 5-class labels (0-4) using quantile thresholds"""
 v = np.log1p(raw_values)
 labels = np.digitize(v, quantiles, right=False) # returns 0..4
 return labels

print(" Quantile thresholds computed for duration and bandwidth")

 Quantile thresholds computed for duration and bandwidth


## 8. Label Encoding and Feature Scaling

In [8]:
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Create labels for all tasks
y_dur_train = create_quantile_labels(train_df['flow_duration'].values, dur_quantiles)
y_dur_test = create_quantile_labels(test_df['flow_duration'].values, dur_quantiles)

y_bw_train = create_quantile_labels(train_df['bandwidth_bps'].values, bw_quantiles)
y_bw_test = create_quantile_labels(test_df['bandwidth_bps'].values, bw_quantiles)

# Traffic classification (label encoding)
le_traf = LabelEncoder()
y_traf_train = le_traf.fit_transform(train_df['label'])
y_traf_test = le_traf.transform(test_df['label'])

# Standardize features
feature_scaler = StandardScaler()
train_df[all_features] = feature_scaler.fit_transform(train_df[all_features])
test_df[all_features] = feature_scaler.transform(test_df[all_features])

print(" Labels created and features standardized")
print(f" Traffic classes: {len(np.unique(y_traf_train))}")
print(f" Duration classes: {len(np.unique(y_dur_train))}")
print(f" Bandwidth classes: {len(np.unique(y_bw_train))}")

 Labels created and features standardized
 Traffic classes: 5
 Duration classes: 5
 Bandwidth classes: 5


## 9. Prepare Feature Matrices

In [9]:
# Extract feature matrices for each task
X_traffic_train = train_df[Xcols_traffic].values
X_duration_train = train_df[Xcols_duration].values
X_bandwidth_train = train_df[Xcols_bandwidth].values

X_traffic_test = test_df[Xcols_traffic].values
X_duration_test = test_df[Xcols_duration].values
X_bandwidth_test = test_df[Xcols_bandwidth].values

print(" Feature matrices extracted")
print(f" Traffic: {X_traffic_train.shape}")
print(f" Duration: {X_duration_train.shape}")
print(f" Bandwidth: {X_bandwidth_train.shape}")

 Feature matrices extracted
 Traffic: (10000, 39)
 Duration: (10000, 39)
 Bandwidth: (10000, 39)


## 10. Feature Padding (Uniform Dimensionality)

In [10]:
# Pad all feature matrices to the same dimension
max_dim = max(X_traffic_train.shape[1], X_duration_train.shape[1], X_bandwidth_train.shape[1])

def pad_features(X, target_size):
    """Pad features with zeros to reach target size"""
    if X.shape[1] < target_size:
        padding = np.zeros((X.shape[0], target_size - X.shape[1]))
        return np.concatenate([X, padding], axis=1)
    return X

X_traffic_train = pad_features(X_traffic_train, max_dim)
X_duration_train = pad_features(X_duration_train, max_dim)
X_bandwidth_train = pad_features(X_bandwidth_train, max_dim)

X_traffic_test = pad_features(X_traffic_test, max_dim)
X_duration_test = pad_features(X_duration_test, max_dim)
X_bandwidth_test = pad_features(X_bandwidth_test, max_dim)

print(f"✓ All features padded to dimension: {max_dim}")
print(f"  Traffic: {X_traffic_train.shape}")
print(f"  Duration: {X_duration_train.shape}")
print(f"  Bandwidth: {X_bandwidth_train.shape}")

✓ All features padded to dimension: 39
  Traffic: (10000, 39)
  Duration: (10000, 39)
  Bandwidth: (10000, 39)


## 11. Mutual Information Analysis (Feature Leakage Check)

In [77]:
from sklearn.feature_selection import mutual_info_classif
import sys

def find_high_mi_features(X_cols, y_train, train_df, task_name, seed, threshold=0.2):
    """Find features with high mutual information (potential label leakage)"""
    X_train = train_df[X_cols].values
    
    try:
        mi_scores = mutual_info_classif(
            X_train,
            y_train,
            discrete_features=False,
            random_state=seed
        )
    except Exception as e:
        print(f"Error calculating MI for {task_name}: {e}", file=sys.stderr)
        return []
    
    mi_results = dict(zip(X_cols, mi_scores))
    
    problematic = []
    for feat, mi in mi_results.items():
        if mi > threshold:
            problematic.append((feat, mi))
    
    if problematic:
        print(f"\n{task_name} - Found {len(problematic)} features with MI > {threshold}:")
        problematic.sort(key=lambda x: x[1], reverse=True)
        for feat, mi in problematic[:10]:  # Show top 10
            print(f"  • {feat}: {mi:.4f}")
        if len(problematic) > 10:
            print(f"  ... and {len(problematic) - 10} more")
    else:
        print(f"\n{task_name} - No features found with MI > {threshold}")
    
    return problematic

print("Analyzing mutual information between features and labels...")

problematic_dur = find_high_mi_features(
    Xcols_duration, y_dur_train, train_df, 'Duration', seed
)

problematic_bw = find_high_mi_features(
    Xcols_bandwidth, y_bw_train, train_df, 'Bandwidth', seed
)

problematic_tf = find_high_mi_features(
    Xcols_traffic, y_traf_train, train_df, 'Traffic', seed
)

all_diagnostics = {
    'duration': problematic_dur,
    'bandwidth': problematic_bw,
    'traffic': problematic_tf
}

print("\n✓ Mutual information analysis complete")
print("  Note: High MI features may indicate correlation with labels")

Analyzing mutual information between features and labels...

Duration - Found 30 features with MI > 0.2:
  • bwd_packets_per_s: 0.9275
  • flow_packets_per_s: 0.9051
  • fwd_packets_per_s: 0.8397
  • flow_iat_mean: 0.7458
  • flow_iat_max: 0.7057
  • flow_iat_std: 0.6616
  • fwd_iat_total: 0.5753
  • fwd_iat_std: 0.5706
  • fwd_iat_max: 0.5513
  • fwd_iat_mean: 0.5258
  ... and 20 more

Bandwidth - Found 30 features with MI > 0.2:
  • bwd_packets_per_s: 1.2198
  • flow_packets_per_s: 1.1862
  • fwd_packets_per_s: 1.0154
  • flow_iat_mean: 0.7372
  • flow_iat_max: 0.6862
  • flow_iat_std: 0.6665
  • fwd_iat_std: 0.5595
  • fwd_iat_mean: 0.5562
  • fwd_iat_max: 0.5272
  • fwd_iat_total: 0.5123
  ... and 20 more

Traffic - Found 35 features with MI > 0.2:
  • bwd_pkt_len_min: 1.1980
  • bwd_seg_size_min: 1.1928
  • fwd_seg_size_min: 1.0199
  • fwd_pkt_len_min: 1.0194
  • bwd_pkt_len_max: 0.9857
  • fwd_pkt_len_max: 0.7743
  • flow_rate_entropy: 0.6934
  • bwd_pkt_len_mean: 0.6692
  • tota

## 12. Client Partitioning

In [12]:
def build_client_partitions(cluster_split='equal', client_split='dirichlet', verbose=True):
    """
    Build client partitions with TWO-LEVEL data distribution:
    - Level 1: Distribute data among CLUSTERS (equal or dirichlet)
    - Level 2: Distribute each cluster's data among CLIENTS (dirichlet)
    
    Args:
        cluster_split: 'equal' or 'dirichlet' - how to split data among clusters
        client_split: 'dirichlet' - how to split data among clients within clusters
        verbose: Print statistics
    
    Returns:
        client_indices_flat: List of client data indices
        client_index_to_cluster: Dict mapping client idx to cluster id
    """
    n_clients = CFG['n_clients_flat']
    n_clusters = CFG['n_clusters']
    clients_per_cluster = CFG['clients_per_cluster']
    alpha_client = CFG['alpha_client']
    alpha_cluster = CFG['alpha_cluster']
    min_size = 50
    rng = np.random.default_rng(seed)
    
    train_indices = np.arange(len(y_traf_train))
    labels = np.unique(y_traf_train)
    
    # Cluster_level_split
    
    if cluster_split == 'equal':
        # Equal split: each cluster gets 1/n_clusters of data
        samples_per_cluster = len(train_indices) // n_clusters
        cluster_indices = []
        
        for cluster_id in range(n_clusters):
            start_idx = cluster_id * samples_per_cluster
            end_idx = start_idx + samples_per_cluster if cluster_id < n_clusters - 1 else len(train_indices)
            cluster_indices.append(train_indices[start_idx:end_idx])
    
    elif cluster_split == 'dirichlet':
        # Dirichlet split: non-IID distribution among clusters
        cluster_bins = [[] for _ in range(n_clusters)]
        label_indices = {}
        
        for lbl in labels:
            label_indices[lbl] = train_indices[y_traf_train == lbl]
        
        for lbl in labels:
            idxs = label_indices[lbl]
            rng.shuffle(idxs)
            proportions = rng.dirichlet([alpha_cluster] * n_clusters)
            cuts = (np.cumsum(proportions) * len(idxs)).astype(int)
            parts = np.split(idxs, cuts[:-1])
            
            for cluster_id, part in enumerate(parts):
                cluster_bins[cluster_id].extend(part.tolist())
        
        cluster_indices = [np.array(sorted(set(cluster_bins[i])), dtype=int) for i in range(n_clusters)]
    
    else:
        raise ValueError(f"Unknown cluster_split: {cluster_split}")

    # Client_level_split
    
    client_indices_flat = []
    client_index_to_cluster = {}
    
    for cluster_id, cluster_data_indices in enumerate(cluster_indices):
        # Get labels for this cluster's data
        cluster_labels = y_traf_train[cluster_data_indices]
        unique_cluster_labels = np.unique(cluster_labels)
        
        # Build client bins for this cluster using Dirichlet
        client_bins = [[] for _ in range(clients_per_cluster)]
        
        for lbl in unique_cluster_labels:
            # Get indices within cluster that have this label
            lbl_mask = cluster_labels == lbl
            lbl_indices = cluster_data_indices[lbl_mask]
            
            if len(lbl_indices) > 0:
                rng.shuffle(lbl_indices)
                proportions = rng.dirichlet([alpha_client] * clients_per_cluster)
                cuts = (np.cumsum(proportions) * len(lbl_indices)).astype(int)
                parts = np.split(lbl_indices, cuts[:-1])
                
                for local_client_id, part in enumerate(parts):
                    client_bins[local_client_id].extend(part.tolist())
        
        # Create clients for this cluster
        for local_client_id in range(clients_per_cluster):
            client_data = np.array(sorted(set(client_bins[local_client_id])), dtype=int)
            
            # Ensure minimum size
            if len(client_data) < min_size:
                need = min_size - len(client_data)
                # Sample from cluster's data
                available = list(set(cluster_data_indices) - set(client_data))
                if len(available) >= need:
                    extra = rng.choice(available, size=need, replace=False)
                else:
                    extra = rng.choice(cluster_data_indices, size=need, replace=True)
                client_data = np.concatenate([client_data, extra])
                client_data = np.unique(client_data).astype(int)
            
            global_client_id = cluster_id * clients_per_cluster + local_client_id
            client_indices_flat.append(client_data.astype(int))
            client_index_to_cluster[global_client_id] = cluster_id
    
    # Statistics
    
    if verbose:
        print(f"\n✓ Created {len(client_indices_flat)} clients")
        print(f"  Cluster split: {cluster_split}")
        print(f"  Client split: {client_split}")
        print(f"  Sample sizes: min={min([len(c) for c in client_indices_flat])}, "
              f"max={max([len(c) for c in client_indices_flat])}, "
              f"avg={np.mean([len(c) for c in client_indices_flat]):.1f}")
        
        print("\n  Cluster distribution:")
        for cluster_id in range(n_clusters):
            cluster_clients = [i for i in range(n_clients) if client_index_to_cluster[i] == cluster_id]
            cluster_samples = sum(len(client_indices_flat[i]) for i in cluster_clients)
            print(f"    Cluster {cluster_id}: {len(cluster_clients)} clients, {cluster_samples} samples")
        
        print("\n  Sample client label distributions:")
        for i in range(min(3, len(client_indices_flat))):
            indices = client_indices_flat[i]
            labels_count = {}
            for lbl in labels:
                count = np.sum(y_traf_train[indices] == lbl)
                if count > 0:
                    labels_count[int(lbl)] = int(count)
            print(f"    Client {i} (Cluster {client_index_to_cluster[i]}): {labels_count}")
    
    return client_indices_flat, client_index_to_cluster
# Build clients with specified split type
client_indices_flat, client_index_to_cluster = build_client_partitions(
    cluster_split=CFG['cluster_split'],  # ✓ NEW: equal or dirichlet for clusters
    client_split=CFG['client_split'],    # ✓ NEW: dirichlet for clients
    verbose=True
)



✓ Created 600 clients
  Cluster split: equal
  Client split: dirichlet
  Sample sizes: min=50, max=92, avg=50.2

  Cluster distribution:
    Cluster 0: 200 clients, 10016 samples
    Cluster 1: 200 clients, 10067 samples
    Cluster 2: 200 clients, 10061 samples

  Sample client label distributions:
    Client 0 (Cluster 0): {0: 16, 1: 9, 2: 9, 3: 7, 4: 9}
    Client 1 (Cluster 0): {0: 16, 1: 11, 2: 8, 3: 9, 4: 6}
    Client 2 (Cluster 0): {0: 8, 1: 12, 2: 5, 3: 13, 4: 12}


In [13]:
class ClientData:
    """Container for client data and metadata"""
    def __init__(self, data_dict, cluster_id):
        self.ds = data_dict
        self.cluster_id = cluster_id


# Create client objects
clients = []

for i, indices in enumerate(client_indices_flat):

    # ensure numpy integer index array
    indices = np.asarray(indices, dtype=np.int32)

    # Slice features
    X_traffic_client   = X_traffic_train[indices].astype(np.float32).copy()
    X_duration_client  = X_duration_train[indices].astype(np.float32).copy()
    X_bandwidth_client = X_bandwidth_train[indices].astype(np.float32).copy()

    # Slice labels
    y_traffic_client   = y_traf_train[indices].astype(np.int32).copy()
    y_duration_client  = y_dur_train[indices].astype(np.int32).copy()
    y_bandwidth_client = y_bw_train[indices].astype(np.int32).copy()

    # Package
    client_data_dict = {
        'traffic':   (X_traffic_client,   y_traffic_client),
        'duration':  (X_duration_client,  y_duration_client),
        'bandwidth': (X_bandwidth_client, y_bandwidth_client)
    }

    # Cluster ID lookup
    cluster_id = client_index_to_cluster[i]

    # Create client object
    clients.append(ClientData(client_data_dict, cluster_id))


# Diagnostics
print("\nClient data structures created")
print(f" Total clients: {len(clients)}")
print(f" Clusters: {CFG['n_clusters']}")

print("\nClient 0 data shapes:")
print(f" Traffic:   X={clients[0].ds['traffic'][0].shape},   y={clients[0].ds['traffic'][1].shape}")
print(f" Duration:  X={clients[0].ds['duration'][0].shape},  y={clients[0].ds['duration'][1].shape}")
print(f" Bandwidth: X={clients[0].ds['bandwidth'][0].shape}, y={clients[0].ds['bandwidth'][1].shape}")



Client data structures created
 Total clients: 600
 Clusters: 3

Client 0 data shapes:
 Traffic:   X=(50, 39),   y=(50,)
 Duration:  X=(50, 39),  y=(50,)
 Bandwidth: X=(50, 39), y=(50,)


## 14. Test Data Preparation

In [14]:
test_data = {
 'traffic': (X_traffic_test.astype(np.float32), y_traf_test.astype(int)),
 'duration': (X_duration_test.astype(np.float32), y_dur_test.astype(int)),
 'bandwidth': (X_bandwidth_test.astype(np.float32), y_bw_test.astype(int))
}

print("\n Test data prepared")
print(f" Traffic: {test_data['traffic'][0].shape}")
print(f" Duration: {test_data['duration'][0].shape}")
print(f" Bandwidth: {test_data['bandwidth'][0].shape}")


 Test data prepared
 Traffic: (2500, 39)
 Duration: (2500, 39)
 Bandwidth: (2500, 39)


In [28]:
# Save preprocessed test data for PyBullet simulation inference
import pickle
import os

# Create test data dict with all preprocessed arrays
preprocessed_test_data = {
    'X_traffic': X_traffic_test.astype(np.float32),
    'X_duration': X_duration_test.astype(np.float32),
    'X_bandwidth': X_bandwidth_test.astype(np.float32),
    'y_traffic': y_traf_test.astype(np.int32),
    'y_duration': y_dur_test.astype(np.int32),
    'y_bandwidth': y_bw_test.astype(np.int32),
    'n_samples': len(y_traf_test),
    'input_dim': X_traffic_test.shape[1],
    'n_classes': {
        'traffic': len(np.unique(y_traf_test)),
        'duration': len(np.unique(y_dur_test)),
        'bandwidth': len(np.unique(y_bw_test))
    },
    'traffic_label_encoder_classes': le_traf.classes_.tolist()  # Save label mapping
}

# Save to trained_models directory
save_path = '/Users/sadmanrahin/Documents/gym-pybullet-drones/cesnet_zoo_clean/trained_models/preprocessed_test_data.pkl'
os.makedirs(os.path.dirname(save_path), exist_ok=True)

with open(save_path, 'wb') as f:
    pickle.dump(preprocessed_test_data, f)

print(f"✓ Preprocessed test data saved to: {save_path}")
print(f"  Samples: {preprocessed_test_data['n_samples']}")
print(f"  Input dim: {preprocessed_test_data['input_dim']}")
print(f"  Classes: {preprocessed_test_data['n_classes']}")
print(f"  Traffic labels: {preprocessed_test_data['traffic_label_encoder_classes']}")

✓ Preprocessed test data saved to: /Users/sadmanrahin/Documents/gym-pybullet-drones/cesnet_zoo_clean/trained_models/preprocessed_test_data.pkl
  Samples: 2500
  Input dim: 39
  Classes: {'traffic': 5, 'duration': 5, 'bandwidth': 5}
  Traffic labels: ['discord', 'facebook-web', 'google-services', 'instagram', 'youtube']


## 15. Data Distribution Summary

In [15]:
def print_distribution(labels, name, mapping=None):
    """Print class distribution with optional name mapping."""
    print(f"\n{name} Distribution:")
    unique, counts = np.unique(labels, return_counts=True)

    for u, c in zip(unique, counts):
        if mapping:
            label_name = mapping.get(u, f"Class {u}")
            print(f" {label_name} ({u}): {c}")
        else:
            print(f" Class {u}: {c}")


duration_map = {
    0: "Very Short",
    1: "Short",
    2: "Medium",
    3: "Long",
    4: "Very Long"
}

bandwidth_map = {
    0: "Very Low",
    1: "Low",
    2: "Medium",
    3: "High",
    4: "Very High"
}

print("DATA DISTRIBUTION SUMMARY")

print("\nUnique Classes:")
print(f" Traffic classes: {len(np.unique(y_traf_train))}")
print(f" Duration classes: {len(np.unique(y_dur_train))}")
print(f" Bandwidth classes: {len(np.unique(y_bw_train))}")

print_distribution(y_dur_train, "Duration (Train)", duration_map)
print_distribution(y_bw_train, "Bandwidth (Train)", bandwidth_map)
print_distribution(y_traf_train, "Traffic (Train)")


DATA DISTRIBUTION SUMMARY

Unique Classes:
 Traffic classes: 5
 Duration classes: 5
 Bandwidth classes: 5

Duration (Train) Distribution:
 Very Short (0): 2000
 Short (1): 2000
 Medium (2): 2000
 Long (3): 2000
 Very Long (4): 2000

Bandwidth (Train) Distribution:
 Very Low (0): 2000
 Low (1): 2000
 Medium (2): 2000
 High (3): 2000
 Very High (4): 2000

Traffic (Train) Distribution:
 Class 0: 1999
 Class 1: 1991
 Class 2: 2006
 Class 3: 1984
 Class 4: 2020


## 16. Model Architecture (FedMTL)

In [71]:
class FedMTLModel(keras.Model):
    """
    Federated Multi-Task Learning Model

    Architecture:
    - Shared layers: 2 dense layers (256 → 128) with dropout
    - Task-specific layers: 1 dense layer per task
    - Task heads: 3 classification heads (traffic, duration, bandwidth)
    """
    def __init__(self, in_dims, n_classes, dropout=0.1):
        super().__init__()
        self.tasks = ['traffic', 'duration', 'bandwidth']

        # Shared layers (learned across all tasks)
        self.shared_dense1 = keras.layers.Dense(256, activation='relu', name='shared_dense1')
        self.shared_drop1  = keras.layers.Dropout(dropout)
        self.shared_dense2 = keras.layers.Dense(128, activation='relu', name='shared_dense2')
        self.shared_drop2  = keras.layers.Dropout(dropout)

        # Task-specific layers
        self.task_dense = {
            'traffic':   keras.layers.Dense(6, activation='relu', name='task_traffic_dense'),
            'duration':  keras.layers.Dense(32, activation='relu', name='task_duration_dense'),
            'bandwidth': keras.layers.Dense(64, activation='relu', name='task_bandwidth_dense'),
        }

        # Task heads (output logits)
        self.task_heads = {
            'traffic':   keras.layers.Dense(n_classes['traffic'],   name='traffic_output'),
            'duration':  keras.layers.Dense(n_classes['duration'],  name='duration_output'),
            'bandwidth': keras.layers.Dense(n_classes['bandwidth'], name='bandwidth_output'),
        }

    def call(self, x, task, training=False):
        """Forward pass for a specific task"""
        # Shared layers
        x = self.shared_dense1(x)
        x = self.shared_drop1(x, training=training)
        x = self.shared_dense2(x)
        x = self.shared_drop2(x, training=training)

        # Task-specific branch
        x = self.task_dense[task](x)

        # Final classification head
        return self.task_heads[task](x)

    def build_all(self, input_dim):
        """Build all task heads with a dummy forward pass"""
        tf.random.set_seed(seed)
        dummy = tf.random.normal((1, input_dim))

        for task in self.tasks:
            _ = self.call(dummy, task=task, training=False)

        self.built = True


# Register in Keras custom objects
tf.keras.utils.get_custom_objects().update({'FedMTLModel': FedMTLModel})


FedMTLModel architecture defined
 Architecture:
 Shared: Input → Dense(256) → Dropout → Dense(128) → Dropout
 Traffic:   → Dense(64) → Dense(n_classes)
 Duration:  → Dense(32) → Dense(n_classes)
 Bandwidth: → Dense(64) → Dense(n_classes)


## 17. Flower Client Implementation

In [72]:
class MTLFlowerClient(fl.client.NumPyClient):
    """
    Flower client for Multi-Task Learning

    Handles:
    - Local training on multiple tasks
    - Parameter synchronization with server
    - Task-specific evaluation
    """

    def __init__(self, model, client_data, cfg, cluster_id):
        self.model = model
        self.client_data = client_data 
        self.cfg = cfg
        self.cluster_id = cluster_id

        # Optimizer
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.cfg['lr'])

        # Loss functions (all classification)
        self.loss_fns = {
            'traffic': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            'duration': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            'bandwidth': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        }

        # Task-specific loss weights
        self.loss_weights = cfg['loss_weights']

    # -------- Utility --------
    def _ensure_model_built(self):
        """Make sure the Keras model is built before use."""
        if self.model.built:
            return

        # Try to build from local data
        x = None
        for task in ['traffic', 'duration', 'bandwidth']:
            if task in self.client_data:
                X_task, _ = self.client_data[task]
                if len(X_task) > 0:
                    x = tf.convert_to_tensor(X_task[:1], dtype=tf.float32)
                    break

        # If this client has absolutely no data, fall back to cfg['max_dim'] if available
        if x is None:
            if 'max_dim' in self.cfg:
                input_dim = self.cfg['max_dim']
            else:
                # Try to infer from any task across this client
                all_dims = [
                    v[0].shape[1] for v in self.client_data.values()
                    if v[0].shape[0] > 0
                ]
                input_dim = all_dims[0] if all_dims else 1
            x = tf.random.normal((1, input_dim))

        for t in ['traffic', 'duration', 'bandwidth']:
            _ = self.model(x, task=t, training=False)
        self.model.built = True

    # -------- Flower API --------
    def get_parameters(self, config):
        """Return current model weights"""
        self._ensure_model_built()
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Local training on client data"""
        self._ensure_model_built()

        # Set global weights
        self.model.set_weights(parameters)

        # Local training loop
        for epoch in range(self.cfg['local_epochs']):
            with tf.GradientTape() as tape:
                total_loss = 0.0
                used_tasks = []

                # Loop through all available tasks
                for task in ['traffic', 'duration', 'bandwidth']:
                    if task not in self.client_data:
                        continue

                    X_task, y_task = self.client_data[task]
                    if len(X_task) == 0:
                        continue

                    # Convert data to tensors
                    X_task_tf = tf.convert_to_tensor(X_task, dtype=tf.float32)
                    y_task_tf = tf.convert_to_tensor(y_task, dtype=tf.int32)

                    # Forward pass
                    logits = self.model(X_task_tf, task=task, training=True)

                    # Compute loss and apply task weight
                    task_loss = self.loss_fns[task](y_task_tf, logits)
                    weighted_loss = task_loss * self.loss_weights[task]

                    total_loss += weighted_loss
                    used_tasks.append(task)

                if len(used_tasks) > 0:
                    # Normalize by sum of weights of tasks that are actually present
                    norm = sum(self.loss_weights[t] for t in used_tasks)
                    total_loss = total_loss / norm

                    # Apply gradients
                    grads = tape.gradient(total_loss, self.model.trainable_weights)
                    if grads is not None and any(g is not None for g in grads):
                        self.optimizer.apply_gradients(
                            zip(grads, self.model.trainable_weights)
                        )
                else:
                    total_loss = tf.constant(0.0, dtype=tf.float32)

        # Return updated weights and metadata
        num_examples = sum(len(data[1]) for data in self.client_data.values())
        avg_loss = float(total_loss.numpy()) if isinstance(total_loss, tf.Tensor) else float(total_loss)

        return self.model.get_weights(), num_examples, {
            "loss": avg_loss,
            "num_tasks": len(self.client_data),
            "cluster_id": self.cluster_id,
            "num_examples": num_examples
        }

    def evaluate(self, parameters, config):
        """Evaluate model on client data"""
        self._ensure_model_built()
        self.model.set_weights(parameters)

        total_loss = 0.0
        total_samples = 0
        task_accuracies = {}
        used_tasks = []

        for task in ['traffic', 'duration', 'bandwidth']:
            if task not in self.client_data:
                continue

            X_task, y_task = self.client_data[task]
            if len(X_task) == 0:
                continue

            X_task_tf = tf.convert_to_tensor(X_task, dtype=tf.float32)
            y_task_tf = tf.convert_to_tensor(y_task, dtype=tf.int32)

            # Forward pass
            logits = self.model(X_task_tf, task=task, training=False)

            # Compute loss and apply weights
            task_loss = self.loss_fns[task](y_task_tf, logits)
            weighted_loss = task_loss * self.loss_weights[task]
            total_loss += weighted_loss

            # Classification evaluation
            predictions = tf.argmax(logits, axis=1)
            accuracy = tf.reduce_mean(
                tf.cast(
                    tf.equal(predictions, tf.cast(y_task_tf, tf.int64)),
                    tf.float32,
                )
            )
            task_accuracies[f"{task}_accuracy"] = float(accuracy)
            task_accuracies[f"{task}_loss"] = float(task_loss)

            total_samples += len(y_task)
            used_tasks.append(task)

        if len(used_tasks) > 0:
            norm = sum(self.loss_weights[t] for t in used_tasks)
            avg_loss = float(total_loss / norm)
            overall_accuracy = np.mean([
                task_accuracies[f"{task}_accuracy"]
                for task in used_tasks
            ])
        else:
            avg_loss = 0.0
            overall_accuracy = 0.0

        task_accuracies["accuracy"] = overall_accuracy

        return float(avg_loss), int(total_samples), task_accuracies


print("\nMTLFlowerClient implementation complete")
print("Features:")
print(" - Multi-task local training")
print(" - Weighted loss aggregation over active tasks")
print(" - Task-specific evaluation")
print(" - Parameter change tracking")



MTLFlowerClient implementation complete
Features:
 - Multi-task local training
 - Weighted loss aggregation over active tasks
 - Task-specific evaluation
 - Parameter change tracking


In [73]:
in_dims = {
 'traffic': max_dim,
 'duration': max_dim,
 'bandwidth': max_dim 
}

n_classes = {
 'traffic': len(np.unique(y_traf_train)),
 'duration': len(np.unique(y_dur_train)),
 'bandwidth': len(np.unique(y_bw_train))
}

print("MODEL CONFIGURATION SUMMARY")
print(f"\nInput dimensions:")
for task, dim in in_dims.items():
 print(f" {task}: {dim}")
print(f"\nNumber of classes:")
for task, n in n_classes.items():
 print(f" {task}: {n}")

print(f"\nTraining configuration:")
print(f" Local epochs: {CFG['local_epochs']}")
print(f" Learning rate: {CFG['lr']}")
print(f" Loss weights: {CFG['loss_weights']}")
print(f" Client participation: {CFG['client_frac']*100}%")

print(f"\nFederation structure:")
print(f" Total clients: {CFG['n_clients_flat']}")
print(f" Number of clusters: {CFG['n_clusters']}")
print(f" Clients per cluster: {CFG['clients_per_cluster']}")
print(f" Global aggregator: Cluster {CFG['global_aggregator_cluster']}")
print(f" Split type: {CFG['cluster_split']}")

MODEL CONFIGURATION SUMMARY

Input dimensions:
 traffic: 39
 duration: 39
 bandwidth: 39

Number of classes:
 traffic: 5
 duration: 5
 bandwidth: 5

Training configuration:
 Local epochs: 1
 Learning rate: 0.001
 Loss weights: {'traffic': 1, 'duration': 1, 'bandwidth': 1}
 Client participation: 100.0%

Federation structure:
 Total clients: 600
 Number of clusters: 3
 Clients per cluster: 200
 Global aggregator: Cluster 1
 Split type: equal


## Setup Verification Complete

### Key Verifications:

**1. Model Architecture** 
- Shared layers: 256 → 128 with dropout (0.1)
- Task-specific layers: Traffic(64), Duration(32), Bandwidth(64)
- All tasks use classification heads with sparse categorical crossentropy
- Model properly handles task-specific forward passes

**2. Data Partitioning** 
- 600 clients across 3 clusters (200 clients each)
- Supports both Equal and Dirichlet splits
- Dirichlet α = 0.4 creates moderate non-IID distribution
- Minimum 50 samples per client enforced
- Sequential cluster assignment:
 - Clients 0-199 → Cluster 0
 - Clients 200-399 → Cluster 1 (Global Aggregator)
 - Clients 400-599 → Cluster 2

**3. Data Preprocessing** 
- Winsorization applied at 1st-99th percentile
- StandardScaler normalization on all features
- Quantile-based binning (5 classes) for duration/bandwidth using log1p transform
- Label encoding for traffic classification (5 classes)
- Train/test split: 80/20

**4. Feature Engineering** 
- High-leakage features removed (TCP flags, headers, idle/active times)
- Identity features excluded (IPs, ports)
- Protocol excluded (QUIC-only dataset)
- Feature padding ensures uniform 39-dimensional input
- Mutual information analysis identifies remaining correlated features

**5. Client Implementation** 
- Multi-task training with weighted loss aggregation
- Gradients computed via TensorFlow GradientTape
- Adam optimizer with lr=1e-3
- Parameter change tracking for debugging
- Cluster-aware metadata included in responses

**6. Data Quality Checks** 
- Balanced classes in all three tasks (5 classes each, ~2000 samples per class)
- Test set: 2,500 samples (20%)
- Training set: 10,000 samples (80%) distributed across 600 clients
- Average ~50 samples per client (range: 48-85)

### Ready for:
- Multi-cluster hierarchical federated learning experiments
- Cluster Head (CH) compromisation scenarios
- Communication cost tracking and analysis
- Byzantine attack simulations
- Model poisoning detection
- Privacy-preserving mechanisms

# HIERARCHICAL FEDERATED LEARNING IMPLEMENTATION

## 23. Shutdown Ray (if running)

In [74]:
# Shutdown Ray to clear all workers and memory
if ray.is_initialized():
 ray.shutdown()
 print(" Ray shutdown complete")
else:
 print(" Ray not running")

 Ray shutdown complete


# CH COMPROMISATION EXPERIMENTS

## Test Plan:
1. **Baseline (100 rounds)**: Normal training to convergence
2. **CH Compromise After Convergence**: Train 100 rounds → Compromise CH → Continue 25 rounds (total 125)
3. **Transient CH Compromise**: Compromise CH during training (125 rounds total)

All tests use the same hierarchical architecture with 3 clusters and CH1 as global aggregator.

## 27c. Training-Only Strategies (Save Models, No Testing)


In [75]:
import pickle

class TrainingOnlyStrategy(fl.server.strategy.FedAvg):
    """
    Training strategy that saves model params every round
    AND evaluates on clients to track training accuracy
    """
    def __init__(self, save_dir='trained_models', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_dir = save_dir
        self.saved_models = []
        os.makedirs(save_dir, exist_ok=True)
        
    def aggregate_fit(self, server_round, results, failures):
        if not results:
            return None, {}
        
        # Standard FedAvg aggregation
        aggregated_params, metrics = super().aggregate_fit(server_round, results, failures)
        
        # Save model params after every round
        model_weights = fl.common.parameters_to_ndarrays(aggregated_params)
        save_path = os.path.join(self.save_dir, f'model_round_{server_round}.pkl')
        
        with open(save_path, 'wb') as f:
            pickle.dump({
                'round': server_round,
                'weights': model_weights,
                'metrics': metrics
            }, f)
        
        self.saved_models.append(save_path)
        
        # Print training progress
        if metrics:
            avg_loss = metrics.get('loss', 0.0)
            print(f"[Round {server_round:3d}] Training Loss: {avg_loss:.4f} | Model saved")
        elif server_round % 20 == 0:
            print(f"[Round {server_round}] Model saved: {save_path}")
        
        return aggregated_params, metrics
    
    def aggregate_evaluate(self, server_round, results, failures):
        """Aggregate evaluation results and print metrics"""
        if not results:
            return None, {}
        
        # Aggregate metrics
        total_loss = 0.0
        total_accuracy = 0.0
        total_samples = 0
        
        # Per-task accuracy aggregation
        task_totals = {
            'traffic_accuracy': 0.0,
            'duration_accuracy': 0.0,
            'bandwidth_accuracy': 0.0
        }
        
        for _, eval_res in results:
            num_examples = eval_res.num_examples
            total_loss += eval_res.loss * num_examples
            if 'accuracy' in eval_res.metrics:
                total_accuracy += eval_res.metrics['accuracy'] * num_examples
            
            # Aggregate per-task accuracies
            for task in ['traffic', 'duration', 'bandwidth']:
                task_key = f'{task}_accuracy'
                if task_key in eval_res.metrics:
                    task_totals[task_key] += eval_res.metrics[task_key] * num_examples
            
            total_samples += num_examples
        
        avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
        avg_accuracy = total_accuracy / total_samples if total_samples > 0 else 0.0
        
        # Calculate per-task averages
        aggregated_metrics = {'accuracy': avg_accuracy, 'loss': avg_loss}
        for task_key in task_totals:
            aggregated_metrics[task_key] = task_totals[task_key] / total_samples if total_samples > 0 else 0.0
        
        # Print evaluation metrics
        if server_round % 10 == 0 or server_round == 1:
            print(f"[Round {server_round:3d}] Eval - Traffic: {aggregated_metrics['traffic_accuracy']:.4f}, "
                  f"Duration: {aggregated_metrics['duration_accuracy']:.4f}, "
                  f"Bandwidth: {aggregated_metrics['bandwidth_accuracy']:.4f}")
        
        return avg_loss, aggregated_metrics

class HierarchicalTrainingOnlyStrategy(fl.server.strategy.FedAvg):
    """
    Hierarchical training-only strategy (saves models, no testing)
    """
    def __init__(self, save_dir='trained_models', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_dir = save_dir
        self.saved_models = []
        self.global_aggregator_cluster = CFG['global_aggregator_cluster']
        os.makedirs(save_dir, exist_ok=True)
        
    def _ndarrays_weighted_average(self, param_list):
        if not param_list:
            return None
        total_weight = float(sum(w for _, w in param_list))
        if total_weight <= 0:
            total_weight = 1.0
        summed = [np.zeros_like(arr, dtype=arr.dtype) for arr in param_list[0][0]]
        for arrays, w in param_list:
            for i, arr in enumerate(arrays):
                summed[i] = summed[i] + (arr * (w / total_weight))
        return summed
    
    def aggregate_fit(self, server_round, results, failures):
        if not results:
            return None, {}
        
        # Tier 1: Aggregate within clusters
        cluster_to_pairs = {}
        for client_proxy, fit_res in results:
            nds = fl.common.parameters_to_ndarrays(fit_res.parameters)
            weight = getattr(fit_res, 'num_examples', None)
            if weight is None:
                weight = int(fit_res.metrics.get('num_examples', 1)) if hasattr(fit_res, 'metrics') else 1
            cluster_id = int(fit_res.metrics.get('cluster_id', 0)) if hasattr(fit_res, 'metrics') else 0
            cluster_to_pairs.setdefault(cluster_id, []).append((nds, weight))
        
        cluster_params = {}
        cluster_weights = {}
        
        for cid, pairs in cluster_to_pairs.items():
            if pairs:
                cluster_params[cid] = self._ndarrays_weighted_average(pairs)
                cluster_weights[cid] = float(sum(w for _, w in pairs))
        
        # Tier 2: Global aggregation at CH1
        global_agg_cluster = self.global_aggregator_cluster
        
        if global_agg_cluster in cluster_params:
            global_pairs = []
            for cid in [0, 2]:
                if cid in cluster_params:
                    global_pairs.append((cluster_params[cid], cluster_weights[cid]))
            
            if global_agg_cluster in cluster_params:
                global_pairs.append((cluster_params[global_agg_cluster], cluster_weights[global_agg_cluster]))
            
            if global_pairs:
                global_params = self._ndarrays_weighted_average(global_pairs)
                aggregated_params = fl.common.ndarrays_to_parameters(global_params)
            else:
                aggregated_params = fl.common.ndarrays_to_parameters(cluster_params[global_agg_cluster])
        else:
            all_pairs = [(cluster_params[cid], cluster_weights[cid]) for cid in cluster_params.keys()]
            if all_pairs:
                global_params = self._ndarrays_weighted_average(all_pairs)
                aggregated_params = fl.common.ndarrays_to_parameters(global_params)
            else:
                return None, {}
        
        # Save model params after every round
        model_weights = fl.common.parameters_to_ndarrays(aggregated_params)
        save_path = os.path.join(self.save_dir, f'model_round_{server_round}.pkl')
        
        with open(save_path, 'wb') as f:
            pickle.dump({
                'round': server_round,
                'weights': model_weights,
                'cluster_params': {cid: params for cid, params in cluster_params.items()},
                'metrics': {'participating_clusters': len(cluster_params)}
            }, f)
        
        self.saved_models.append(save_path)
        
        # Print training progress
        print(f"[Round {server_round:3d}] Clusters: {len(cluster_params)} | Model saved")
        
        return aggregated_params, {}
    
    def aggregate_evaluate(self, server_round, results, failures):
        """Aggregate evaluation results and print metrics"""
        if not results:
            return None, {}
        
        # Aggregate metrics
        total_loss = 0.0
        total_accuracy = 0.0
        total_samples = 0
        
        # Per-task accuracy aggregation
        task_totals = {
            'traffic_accuracy': 0.0,
            'duration_accuracy': 0.0,
            'bandwidth_accuracy': 0.0
        }
        
        for _, eval_res in results:
            num_examples = eval_res.num_examples
            total_loss += eval_res.loss * num_examples
            if 'accuracy' in eval_res.metrics:
                total_accuracy += eval_res.metrics['accuracy'] * num_examples
            
            # Aggregate per-task accuracies
            for task in ['traffic', 'duration', 'bandwidth']:
                task_key = f'{task}_accuracy'
                if task_key in eval_res.metrics:
                    task_totals[task_key] += eval_res.metrics[task_key] * num_examples
            
            total_samples += num_examples
        
        avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
        avg_accuracy = total_accuracy / total_samples if total_samples > 0 else 0.0
        
        # Calculate per-task averages
        aggregated_metrics = {'accuracy': avg_accuracy, 'loss': avg_loss}
        for task_key in task_totals:
            aggregated_metrics[task_key] = task_totals[task_key] / total_samples if total_samples > 0 else 0.0
        
        # Print evaluation metrics
        if server_round % 10 == 0 or server_round == 1:
            print(f"[Round {server_round:3d}] Eval - Traffic: {aggregated_metrics['traffic_accuracy']:.4f}, "
                  f"Duration: {aggregated_metrics['duration_accuracy']:.4f}, "
                  f"Bandwidth: {aggregated_metrics['bandwidth_accuracy']:.4f}")
        
        return avg_loss, aggregated_metrics

print("Training strategies ready (save models + track per-task accuracy during training)")


Training strategies ready (save models + track per-task accuracy during training)


## PHASE 1: TRAINING (Save Models Every Round)


In [None]:
TRAINING_ROUNDS = 125

# Shutdown Ray
if ray.is_initialized():
    ray.shutdown()

print("="*80)
print("TRAINING PHASE: SAVING MODELS FOR LATER TESTING")
print("="*80)
print(f"\nTraining Configuration:")
print(f"  Rounds: {TRAINING_ROUNDS}")
print(f"  Save: Model params after EVERY round")
print(f"  Testing: Done separately after training")
print(f"\n  Cluster-level: EQUAL split (for training)")
print(f"  Client-level: Dirichlet split (alpha={CFG['alpha_client']})")
print("="*80)

# Build client partitions with EQUAL cluster split for training
print("\nBuilding client partitions for training (equal cluster split)...")
client_indices_flat, client_index_to_cluster = build_client_partitions(
    cluster_split='equal',
    client_split='dirichlet',
    verbose=True
)

# Create client objects
clients = []
for i, indices in enumerate(client_indices_flat):
    X_traffic_client = X_traffic_train[indices]
    X_duration_client = X_duration_train[indices]
    X_bandwidth_client = X_bandwidth_train[indices]
    
    y_traffic_client = y_traf_train[indices]
    y_duration_client = y_dur_train[indices]
    y_bandwidth_client = y_bw_train[indices]
    
    client_data_dict = {
        'traffic': (X_traffic_client.astype(np.float32), y_traffic_client),
        'duration': (X_duration_client.astype(np.float32), y_duration_client),
        'bandwidth': (X_bandwidth_client.astype(np.float32), y_bandwidth_client)
    }
    
    cluster_id = client_index_to_cluster[i]
    clients.append(ClientData(client_data_dict, cluster_id))

print(f" Clients created: {len(clients)}")

# Prepare for training
max_dim = max(in_dims.values())
in_dims_uniform = {
    'traffic': max_dim,
    'duration': max_dim,
    'bandwidth': max_dim
}

def client_fn(context: fl.common.Context) -> fl.client.Client:
    tf.random.set_seed(seed)
    client_id = hash(context.node_id) % len(clients)
    client_obj = clients[client_id]
    client_data = client_obj.ds
    cluster_id = client_obj.cluster_id
    
    model = FedMTLModel(in_dims_uniform, n_classes, dropout=0.1)
    model.build_all(max_dim)
    
    numpy_client = MTLFlowerClient(model, client_data, CFG, cluster_id)
    return numpy_client.to_client()

# Create global model template
global_model_template = FedMTLModel(in_dims_uniform, n_classes, dropout=0.1)
global_model_template.build_all(max_dim)

def aggregate_metrics(metrics):
    aggregated = {}
    for num_examples, client_metrics in metrics:
        for metric_name, metric_value in client_metrics.items():
            if metric_name not in aggregated:
                aggregated[metric_name] = []
            aggregated[metric_name].append(metric_value)
    for metric_name in aggregated:
        aggregated[metric_name] = np.mean(aggregated[metric_name])
    return aggregated

print("\n" + "="*80)
print("TRAINING 1/2: SINGLE CLUSTER")
print("="*80)

# Single cluster training
strategy_single = TrainingOnlyStrategy(
    save_dir='trained_models/single_cluster',
    fraction_fit=CFG['client_frac'],
    fraction_evaluate=CFG['client_frac'],
    min_fit_clients=10,
    min_available_clients=10,
    initial_parameters=fl.common.ndarrays_to_parameters(global_model_template.get_weights()),
    fit_metrics_aggregation_fn=aggregate_metrics,
    evaluate_metrics_aggregation_fn=aggregate_metrics
)

history_single = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=len(clients),
    config=fl.server.ServerConfig(num_rounds=TRAINING_ROUNDS),
    strategy=strategy_single,
    client_resources={'num_cpus': 1.0, 'num_gpus': 0.0},
)

print(f"\nSingle cluster training complete!")
print(f"  Saved {len(strategy_single.saved_models)} models")

# Shutdown Ray
if ray.is_initialized():
    ray.shutdown()

print("\n" + "="*80)
print("TRAINING 2/2: THREE CLUSTER HIERARCHICAL")
print("="*80)

# Hierarchical training
strategy_hierarchical = HierarchicalTrainingOnlyStrategy(
    save_dir='trained_models/hierarchical_equal',
    fraction_fit=CFG['client_frac'],
    fraction_evaluate=CFG['client_frac'],
    min_fit_clients=10,
    min_available_clients=10,
    initial_parameters=fl.common.ndarrays_to_parameters(global_model_template.get_weights()),
    fit_metrics_aggregation_fn=aggregate_metrics,
    evaluate_metrics_aggregation_fn=aggregate_metrics
)

history_hierarchical = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=len(clients),
    config=fl.server.ServerConfig(num_rounds=TRAINING_ROUNDS),
    strategy=strategy_hierarchical,
    client_resources={'num_cpus': 1.0, 'num_gpus': 0.0},
)

print(f"\nHierarchical training complete!")
print(f"  Saved {len(strategy_hierarchical.saved_models)} models")

print("\n" + "="*80)
print("TRAINING PHASE COMPLETE!")
print("="*80)
print(f"\nTotal models saved:")
print(f"  Single cluster: {len(strategy_single.saved_models)} models")
print(f"  Hierarchical: {len(strategy_hierarchical.saved_models)} models")
print(f"  TOTAL: {len(strategy_single.saved_models) + len(strategy_hierarchical.saved_models)} models")
print(f"\nReady for testing phase!")


TRAINING PHASE: SAVING MODELS FOR LATER TESTING

Training Configuration:
  Rounds: 125
  Save: Model params after EVERY round
  Testing: Done separately after training

  Cluster-level: EQUAL split (for training)
  Client-level: Dirichlet split (alpha=0.4)

Building client partitions for training (equal cluster split)...

✓ Created 600 clients
  Cluster split: equal
  Client split: dirichlet
  Sample sizes: min=50, max=92, avg=50.2

  Cluster distribution:
    Cluster 0: 200 clients, 10016 samples
    Cluster 1: 200 clients, 10067 samples
    Cluster 2: 200 clients, 10061 samples

  Sample client label distributions:
    Client 0 (Cluster 0): {0: 16, 1: 9, 2: 9, 3: 7, 4: 9}
    Client 1 (Cluster 0): {0: 16, 1: 11, 2: 8, 3: 9, 4: 6}
    Client 2 (Cluster 0): {0: 8, 1: 12, 2: 5, 3: 13, 4: 12}
 Clients created: 600


	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=125, no round_timeout



TRAINING 1/2: SINGLE CLUSTER


2025-12-01 21:05:16,727	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 8.0, 'memory': 8194755789.0, 'node:127.0.0.1': 1.0, 'object_store_memory': 2147483648.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1.0, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 8 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 resul

[Round   1] Training Loss: 1.6322 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round   1] Eval - Traffic: 0.2579, Duration: 0.2558, Bandwidth: 0.1849


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   2] Training Loss: 1.5866 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   3] Training Loss: 1.5515 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   4] Training Loss: 1.5214 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   5] Training Loss: 1.4950 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   6] Training Loss: 1.4711 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   7] Training Loss: 1.4477 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   8] Training Loss: 1.4251 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round   9] Training Loss: 1.4036 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  10] Training Loss: 1.3818 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 11]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round  10] Eval - Traffic: 0.5389, Duration: 0.5292, Bandwidth: 0.5156


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  11] Training Loss: 1.3624 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 12]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  12] Training Loss: 1.3431 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 13]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  13] Training Loss: 1.3237 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 14]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  14] Training Loss: 1.3056 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 15]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  15] Training Loss: 1.2870 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 16]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  16] Training Loss: 1.2696 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 17]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  17] Training Loss: 1.2526 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 18]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  18] Training Loss: 1.2363 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 19]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  19] Training Loss: 1.2201 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 20]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  20] Training Loss: 1.2043 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 21]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round  20] Eval - Traffic: 0.5852, Duration: 0.5736, Bandwidth: 0.5974


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  21] Training Loss: 1.1901 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 22]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  22] Training Loss: 1.1746 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 23]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  23] Training Loss: 1.1600 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 24]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  24] Training Loss: 1.1457 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 25]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  25] Training Loss: 1.1324 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 26]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  26] Training Loss: 1.1192 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 27]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  27] Training Loss: 1.1057 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 28]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  28] Training Loss: 1.0921 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 29]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  29] Training Loss: 1.0811 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 30]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  30] Training Loss: 1.0678 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 31]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round  30] Eval - Traffic: 0.6120, Duration: 0.6101, Bandwidth: 0.6679


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  31] Training Loss: 1.0559 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 32]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  32] Training Loss: 1.0428 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 33]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  33] Training Loss: 1.0325 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 34]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  34] Training Loss: 1.0206 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 35]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  35] Training Loss: 1.0095 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 36]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  36] Training Loss: 0.9986 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 37]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  37] Training Loss: 0.9877 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 38]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  38] Training Loss: 0.9779 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 39]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  39] Training Loss: 0.9677 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 40]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  40] Training Loss: 0.9560 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 41]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round  40] Eval - Traffic: 0.6347, Duration: 0.6429, Bandwidth: 0.7485


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  41] Training Loss: 0.9461 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 42]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  42] Training Loss: 0.9355 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 43]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  43] Training Loss: 0.9273 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 44]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  44] Training Loss: 0.9172 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 45]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  45] Training Loss: 0.9084 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 46]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  46] Training Loss: 0.8996 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 47]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  47] Training Loss: 0.8902 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 48]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  48] Training Loss: 0.8814 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 49]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  49] Training Loss: 0.8749 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 50]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  50] Training Loss: 0.8680 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 51]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round  50] Eval - Traffic: 0.6459, Duration: 0.6771, Bandwidth: 0.7911


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  51] Training Loss: 0.8584 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 52]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  52] Training Loss: 0.8515 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 53]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  53] Training Loss: 0.8446 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 54]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  54] Training Loss: 0.8383 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 55]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  55] Training Loss: 0.8299 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 56]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  56] Training Loss: 0.8234 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 57]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  57] Training Loss: 0.8161 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 58]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  58] Training Loss: 0.8107 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 59]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  59] Training Loss: 0.8034 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 60]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  60] Training Loss: 0.7985 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 61]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)


[Round  60] Eval - Traffic: 0.6545, Duration: 0.7176, Bandwidth: 0.8120


[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  61] Training Loss: 0.7927 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 62]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  62] Training Loss: 0.7874 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 63]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  63] Training Loss: 0.7802 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 64]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
[92mINFO [0m:      aggregate_fit: received 600 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 600 clients (out of 600)


[Round  64] Training Loss: 0.7751 | Model saved


[92mINFO [0m:      aggregate_evaluate: received 600 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 65]
[92mINFO [0m:      configure_fit: strategy sampled 600 clients (out of 600)
python(98635) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


KeyboardInterrupt: 

## Training Accuracy Visualization (from history object)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

print("="*80)
print("TRAINING ACCURACY CURVES (During Training)")
print("="*80)

# Check if training history objects are available
if 'history_single' not in locals() or 'history_hierarchical' not in locals():
    print("⚠️ Training history not available. Run training cells first (cell 45).")
else:
    # Create figure with 2 subplots (single cluster vs hierarchical)
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    experiments = [
        ('history_single', history_single, 'Single Cluster', axes[0]),
        ('history_hierarchical', history_hierarchical, 'Hierarchical (Equal Split)', axes[1])
    ]
    
    for exp_name, history, title, ax in experiments:
        # Check if evaluation metrics are available
        if not hasattr(history, 'metrics_distributed') or not history.metrics_distributed:
            ax.text(0.5, 0.5, f'No evaluation metrics available\nfor {exp_name}', 
                   ha='center', va='center', transform=ax.transAxes, fontsize=12)
            ax.set_title(title, fontsize=12, fontweight='bold')
            continue
        
        metrics = history.metrics_distributed
        
        # Check for per-task accuracy metrics
        required_metrics = ['traffic_accuracy', 'duration_accuracy', 'bandwidth_accuracy']
        missing = [m for m in required_metrics if m not in metrics]
        
        if missing:
            ax.text(0.5, 0.5, f'Missing metrics:\n{", ".join(missing)}', 
                   ha='center', va='center', transform=ax.transAxes, fontsize=10)
            ax.set_title(title, fontsize=12, fontweight='bold')
            continue
        
        # Extract rounds and accuracies
        rounds = [r for r, _ in metrics['traffic_accuracy']]
        traffic_acc = [float(v) for _, v in metrics['traffic_accuracy']]
        duration_acc = [float(v) for _, v in metrics['duration_accuracy']]
        bandwidth_acc = [float(v) for _, v in metrics['bandwidth_accuracy']]
        
        # Plot per-task accuracies
        ax.plot(rounds, traffic_acc, color='green', label='Traffic Classification', 
               linewidth=2, marker='o', markersize=4, markevery=max(1, len(rounds)//10))
        ax.plot(rounds, duration_acc, color='blue', label='Flow Duration Classification', 
               linewidth=2, marker='s', markersize=4, markevery=max(1, len(rounds)//10))
        ax.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
               linewidth=2, marker='^', markersize=4, markevery=max(1, len(rounds)//10))
        
        # Format subplot
        ax.set_xlabel('Rounds', fontsize=11)
        ax.set_ylabel('Accuracy', fontsize=11)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)
        ax.grid(True, alpha=0.3, linestyle='--')
        
        # Set y-axis limits with padding
        all_acc = traffic_acc + duration_acc + bandwidth_acc
        y_min = max(0.0, min(all_acc) - 0.05)
        y_max = min(1.0, max(all_acc) + 0.05)
        ax.set_ylim(y_min, y_max)
        
        # Print summary
        print(f"\n{title}:")
        print(f"  Rounds: {len(rounds)}")
        print(f"  Final Accuracies:")
        print(f"    Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
        print(f"    Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
        print(f"    Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
        print(f"  Best Accuracies:")
        print(f"    Traffic:   {max(traffic_acc):.4f} at Round {rounds[np.argmax(traffic_acc)]}")
        print(f"    Duration:  {max(duration_acc):.4f} at Round {rounds[np.argmax(duration_acc)]}")
        print(f"    Bandwidth: {max(bandwidth_acc):.4f} at Round {rounds[np.argmax(bandwidth_acc)]}")
    
    plt.suptitle('Training Accuracy Curves (Evaluated on Client Data During Training)', 
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*80)
    print("Note: These accuracies are from evaluating the global model on client data")
    print("during training (not on the separate test set).")
    print("="*80)

## PHASE 2: TESTING (Evaluate Saved Models)


In [None]:
def create_test_data_partitions(cluster_split='equal'):
    """
    Create test data partitions for evaluation
    
    Args:
        cluster_split: 'equal' or 'dirichlet' - how to split test data among clusters
    
    Returns:
        test_data_dict: Dictionary with cluster-level and global test data
    """
    n_clusters = CFG['n_clusters']
    alpha_cluster = CFG['alpha_cluster']
    rng = np.random.default_rng(seed + 1000)  # Different seed for test
    
    test_indices = np.arange(len(y_traf_test))
    labels_test = np.unique(y_traf_test)
    
    if cluster_split == 'equal':
        # Equal split: each cluster gets 1/3 of test data
        samples_per_cluster = len(test_indices) // n_clusters
        cluster_test_indices = []
        for cluster_id in range(n_clusters):
            start_idx = cluster_id * samples_per_cluster
            end_idx = start_idx + samples_per_cluster if cluster_id < n_clusters - 1 else len(test_indices)
            cluster_test_indices.append(test_indices[start_idx:end_idx])
    
    elif cluster_split == 'dirichlet':
        # Dirichlet split: non-IID distribution among clusters
        cluster_bins = [[] for _ in range(n_clusters)]
        label_indices_test = {}
        
        for lbl in labels_test:
            label_indices_test[lbl] = test_indices[y_traf_test == lbl]
        
        for lbl in labels_test:
            idxs = label_indices_test[lbl]
            rng.shuffle(idxs)
            proportions = rng.dirichlet([alpha_cluster] * n_clusters)
            cuts = (np.cumsum(proportions) * len(idxs)).astype(int)
            parts = np.split(idxs, cuts[:-1])
            
            for cluster_id, part in enumerate(parts):
                cluster_bins[cluster_id].extend(part.tolist())
        
        cluster_test_indices = [np.array(sorted(set(cluster_bins[i]))) for i in range(n_clusters)]
    
    else:
        raise ValueError(f"Unknown cluster_split: {cluster_split}")
    
    # Create test data dictionary
    test_data_dict = {
        'global': {
            'X_traffic': X_traffic_test.astype(np.float32),
            'y_traffic': y_traf_test.astype(int),
            'X_duration': X_duration_test.astype(np.float32),
            'y_duration': y_dur_test.astype(int),
            'X_bandwidth': X_bandwidth_test.astype(np.float32),
            'y_bandwidth': y_bw_test.astype(int)
        },
        'clusters': {}
    }
    
    # Add per-cluster test data
    for cluster_id in range(n_clusters):
        indices = cluster_test_indices[cluster_id]
        test_data_dict['clusters'][cluster_id] = {
            'X_traffic': X_traffic_test[indices].astype(np.float32),
            'y_traffic': y_traf_test[indices].astype(int),
            'X_duration': X_duration_test[indices].astype(np.float32),
            'y_duration': y_dur_test[indices].astype(int),
            'X_bandwidth': X_bandwidth_test[indices].astype(np.float32),
            'y_bandwidth': y_bw_test[indices].astype(int),
            'size': len(indices)
        }
    
    print(f"\nTest data partitioned ({cluster_split} split):")
    print(f"  Global: {len(test_indices)} samples")
    for cluster_id in range(n_clusters):
        print(f"  Cluster {cluster_id}: {test_data_dict['clusters'][cluster_id]['size']} samples")
    
    return test_data_dict

def evaluate_model_on_test(model_weights, test_data_dict, model_type='global', cluster_id=None):
    """
    Evaluate saved model on test data
    
    Args:
        model_weights: Saved model weights
        test_data_dict: Test data dictionary
        model_type: 'global' or 'cluster'
        cluster_id: Which cluster to test (if model_type='cluster')
    
    Returns:
        Dictionary of accuracies
    """
    # Create model and load weights
    model = FedMTLModel(in_dims_uniform, n_classes, dropout=0.1)
    model.build_all(max_dim)
    model.set_weights(model_weights)
    
    # Get test data
    if model_type == 'global':
        test_data = test_data_dict['global']
    elif model_type == 'cluster' and cluster_id is not None:
        test_data = test_data_dict['clusters'][cluster_id]
    else:
        raise ValueError("Invalid model_type or missing cluster_id")
    
    # Evaluate on all tasks
    accuracies = {}
    for task in ['traffic', 'duration', 'bandwidth']:
        X_test = test_data[f'X_{task}']
        y_test = test_data[f'y_{task}']
        
        logits = model(X_test, task=task, training=False)
        predictions = tf.argmax(logits, axis=1).numpy()
        
        accuracy = np.mean(predictions == y_test)
        accuracies[f'{task}_accuracy'] = float(accuracy)
    
    accuracies['overall_accuracy'] = np.mean([
        accuracies['traffic_accuracy'],
        accuracies['duration_accuracy'],
        accuracies['bandwidth_accuracy']
    ])
    
    return accuracies

print("Test data partitioning and evaluation functions ready")


## 28. Test Evaluation: Single Cluster & Three Cluster Equal Split


In [None]:
print("="*80)
print("TESTING PHASE: EVALUATE SAVED MODELS")
print("="*80)
print("\nEvaluating first 100 rounds only (convergence period)")
print("Testing with EQUAL cluster split\n")

# Create test data with equal split
test_data_equal = create_test_data_partitions(cluster_split='equal')

# Store results
test_results = {
    'single_cluster': [],
    'hierarchical_equal': []
}

# Evaluate single cluster models (rounds 1-100)
print("\n" + "="*80)
print("EVALUATING: SINGLE CLUSTER MODELS")
print("="*80)

for round_num in range(1, 101):
    model_path = f'trained_models/single_cluster/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Evaluate on global test data
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    test_results['single_cluster'].append(accuracies)
    
    if round_num % 10 == 0 or round_num == 1:
        print(f"[Round {round_num:3d}/100] Traffic: {accuracies['traffic_accuracy']:.4f}, "
              f"Duration: {accuracies['duration_accuracy']:.4f}, "
              f"Bandwidth: {accuracies['bandwidth_accuracy']:.4f}")

print(f"\n✅ Single cluster evaluation complete: {len(test_results['single_cluster'])} rounds")

# Evaluate hierarchical models (rounds 1-100)
print("\n" + "="*80)
print("EVALUATING: HIERARCHICAL MODELS (Equal Cluster Split)")
print("="*80)

for round_num in range(1, 101):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Evaluate on global test data
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    test_results['hierarchical_equal'].append(accuracies)
    
    if round_num % 10 == 0 or round_num == 1:
        print(f"[Round {round_num:3d}/100] Traffic: {accuracies['traffic_accuracy']:.4f}, "
              f"Duration: {accuracies['duration_accuracy']:.4f}, "
              f"Bandwidth: {accuracies['bandwidth_accuracy']:.4f}")

print(f"\n✅ Hierarchical evaluation complete: {len(test_results['hierarchical_equal'])} rounds")

print("\n" + "="*80)
print("EVALUATION COMPLETE")
print("="*80)
print(f"\nResults stored:")
print(f"  test_results['single_cluster']: {len(test_results['single_cluster'])} rounds")
print(f"  test_results['hierarchical_equal']: {len(test_results['hierarchical_equal'])} rounds")


## 29. Test Evaluation: Three Cluster with Dirichlet Split (Per-Cluster Graphs)


In [None]:
print("="*80)
print("TESTING WITH DIRICHLET CLUSTER SPLIT (Per-Cluster Evaluation)")
print("="*80)
print("\nEvaluating hierarchical models on Dirichlet cluster split")
print("Shows per-cluster performance when clusters have non-IID data\n")

# Create test data with dirichlet split
test_data_dirichlet = create_test_data_partitions(cluster_split='dirichlet')

# Store results
test_results['hierarchical_dirichlet_global'] = []
test_results['hierarchical_dirichlet_per_cluster'] = {
    0: [],
    1: [],
    2: []
}

# Evaluate hierarchical models (rounds 1-100)
print("\n" + "="*80)
print("EVALUATING: HIERARCHICAL MODELS (Dirichlet Cluster Split)")
print("="*80)

for round_num in range(1, 101):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global evaluation
    accuracies_global = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='global')
    accuracies_global['round'] = round_num
    test_results['hierarchical_dirichlet_global'].append(accuracies_global)
    
    # Per-cluster evaluation
    for cluster_id in range(CFG['n_clusters']):
        accuracies_cluster = evaluate_model_on_test(saved['weights'], test_data_dirichlet, 
                                                     model_type='cluster', cluster_id=cluster_id)
        accuracies_cluster['round'] = round_num
        test_results['hierarchical_dirichlet_per_cluster'][cluster_id].append(accuracies_cluster)
    
    if round_num % 10 == 0 or round_num == 1:
        print(f"[Round {round_num:3d}/100] Global - Traffic: {accuracies_global['traffic_accuracy']:.4f}")
        for cluster_id in range(CFG['n_clusters']):
            acc = test_results['hierarchical_dirichlet_per_cluster'][cluster_id][-1]
            print(f"               Cluster {cluster_id} - Traffic: {acc['traffic_accuracy']:.4f}")

print(f"\n✅ Hierarchical Dirichlet evaluation complete:")
print(f"  Global: {len(test_results['hierarchical_dirichlet_global'])} rounds")
print(f"  Per-cluster: {len(test_results['hierarchical_dirichlet_per_cluster'][0])} rounds each")

print("\n" + "="*80)
print("ALL EVALUATIONS COMPLETE")
print("="*80)
print(f"\nTest results available:")
print(f"  test_results['single_cluster']")
print(f"  test_results['hierarchical_equal']")
print(f"  test_results['hierarchical_dirichlet_global']")
print(f"  test_results['hierarchical_dirichlet_per_cluster'][0/1/2]")


## 30. CH Compromise After Convergence (Testing Phase)


In [None]:
print("="*80)
print("CH COMPROMISE AFTER CONVERGENCE (Testing Phase)")
print("="*80)
print("\nScenario: CH0 compromised at round 111, detected at round 112")
print("Uses trained models from rounds 1-100, then extends to 125")
print("Testing with EQUAL and DIRICHLET cluster splits\n")

# Store compromise results (global and per-cluster)
test_results['compromise_after_convergence'] = []
test_results['compromise_after_convergence_per_cluster_equal'] = {0: [], 1: [], 2: []}
test_results['compromise_after_convergence_per_cluster_dirichlet'] = {0: [], 1: [], 2: []}

# Create test data partitions for both splits (if not already created)
if 'test_data_equal' not in globals():
    test_data_equal = create_test_data_partitions(cluster_split='equal')
if 'test_data_dirichlet' not in globals():
    test_data_dirichlet = create_test_data_partitions(cluster_split='dirichlet')

print("Testing rounds 1-110: Normal operation")
for round_num in range(1, 111):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global evaluation
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'normal'
    test_results['compromise_after_convergence'].append(accuracies)
    
    # Per-cluster evaluation (equal split)
    for cluster_id in range(CFG['n_clusters']):
        acc_cluster = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_cluster['round'] = round_num
        acc_cluster['phase'] = 'normal'
        test_results['compromise_after_convergence_per_cluster_equal'][cluster_id].append(acc_cluster)
    
    # Per-cluster evaluation (dirichlet split)
    for cluster_id in range(CFG['n_clusters']):
        acc_cluster = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_cluster['round'] = round_num
        acc_cluster['phase'] = 'normal'
        test_results['compromise_after_convergence_per_cluster_dirichlet'][cluster_id].append(acc_cluster)
    
    if round_num % 10 == 0 or round_num == 1:
        print(f"[Round {round_num:3d}/110] Normal - Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRound 111: CH0 compromised (but not detected yet)")
model_path = f'trained_models/hierarchical_equal/model_round_111.pkl'
with open(model_path, 'rb') as f:
    saved = pickle.load(f)

# Global + per-cluster
accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
accuracies['round'] = 111
accuracies['phase'] = 'compromised'
test_results['compromise_after_convergence'].append(accuracies)

for cluster_id in range(CFG['n_clusters']):
    acc_cluster_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
    acc_cluster_eq['round'] = 111
    acc_cluster_eq['phase'] = 'compromised'
    test_results['compromise_after_convergence_per_cluster_equal'][cluster_id].append(acc_cluster_eq)
    
    acc_cluster_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
    acc_cluster_dir['round'] = 111
    acc_cluster_dir['phase'] = 'compromised'
    test_results['compromise_after_convergence_per_cluster_dirichlet'][cluster_id].append(acc_cluster_dir)

print(f"[Round 111] Compromised - Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRound 112: Compromise detected, D&R-E phase begins")
print("Rounds 112-118: D&R-E Phase (7 rounds) - CH0 offline, cluster excluded")

for round_num in range(112, 119):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'dre'
    test_results['compromise_after_convergence'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_cluster_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_cluster_eq['round'] = round_num
        acc_cluster_eq['phase'] = 'dre'
        test_results['compromise_after_convergence_per_cluster_equal'][cluster_id].append(acc_cluster_eq)
        
        acc_cluster_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_cluster_dir['round'] = round_num
        acc_cluster_dir['phase'] = 'dre'
        test_results['compromise_after_convergence_per_cluster_dirichlet'][cluster_id].append(acc_cluster_dir)
    
    print(f"[Round {round_num}] D&R-E - Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRounds 119-121: Continuity Phase (3 rounds) - Gradual re-entry")
continuity_rates = {119: 0.30, 120: 0.70, 121: 1.00}

for round_num in range(119, 122):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'continuity'
    accuracies['participation_rate'] = continuity_rates[round_num]
    test_results['compromise_after_convergence'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_cluster_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_cluster_eq['round'] = round_num
        acc_cluster_eq['phase'] = 'continuity'
        acc_cluster_eq['participation_rate'] = continuity_rates[round_num]
        test_results['compromise_after_convergence_per_cluster_equal'][cluster_id].append(acc_cluster_eq)
        
        acc_cluster_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_cluster_dir['round'] = round_num
        acc_cluster_dir['phase'] = 'continuity'
        acc_cluster_dir['participation_rate'] = continuity_rates[round_num]
        test_results['compromise_after_convergence_per_cluster_dirichlet'][cluster_id].append(acc_cluster_dir)
    
    print(f"[Round {round_num}] Continuity ({int(continuity_rates[round_num]*100)}%) - "
          f"Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRounds 122-125: Re-stabilization Phase")
for round_num in range(122, 126):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'stabilization'
    test_results['compromise_after_convergence'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_cluster_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_cluster_eq['round'] = round_num
        acc_cluster_eq['phase'] = 'stabilization'
        test_results['compromise_after_convergence_per_cluster_equal'][cluster_id].append(acc_cluster_eq)
        
        acc_cluster_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_cluster_dir['round'] = round_num
        acc_cluster_dir['phase'] = 'stabilization'
        test_results['compromise_after_convergence_per_cluster_dirichlet'][cluster_id].append(acc_cluster_dir)
    
    print(f"[Round {round_num}] Stabilization - Traffic: {accuracies['traffic_accuracy']:.4f}")

print(f"\n✅ CH Compromise After Convergence complete: {len(test_results['compromise_after_convergence'])} rounds")


## 31. Transient CH Compromise (Testing Phase)


In [None]:
print("="*80)
print("TRANSIENT CH COMPROMISE (Testing Phase)")
print("="*80)
print("\nScenario: CH0 compromised at round 10, detected at round 11")
print("Testing 30 rounds total with early compromise")
print("Testing with EQUAL and DIRICHLET cluster splits\n")

# Store transient results (global and per-cluster)
test_results['transient_compromise'] = []
test_results['transient_compromise_per_cluster_equal'] = {0: [], 1: [], 2: []}
test_results['transient_compromise_per_cluster_dirichlet'] = {0: [], 1: [], 2: []}

# Create test data partitions (if not already created)
if 'test_data_equal' not in globals():
    test_data_equal = create_test_data_partitions(cluster_split='equal')
if 'test_data_dirichlet' not in globals():
    test_data_dirichlet = create_test_data_partitions(cluster_split='dirichlet')

print("Testing rounds 1-9: Normal operation")
for round_num in range(1, 10):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'normal'
    test_results['transient_compromise'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_eq['round'] = round_num
        acc_eq['phase'] = 'normal'
        test_results['transient_compromise_per_cluster_equal'][cluster_id].append(acc_eq)
        
        acc_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_dir['round'] = round_num
        acc_dir['phase'] = 'normal'
        test_results['transient_compromise_per_cluster_dirichlet'][cluster_id].append(acc_dir)
    
    print(f"[Round {round_num}] Normal - Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRound 10: CH0 compromised (but not detected yet)")
model_path = f'trained_models/hierarchical_equal/model_round_10.pkl'
with open(model_path, 'rb') as f:
    saved = pickle.load(f)

# Global + per-cluster
accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
accuracies['round'] = 10
accuracies['phase'] = 'compromised'
test_results['transient_compromise'].append(accuracies)

for cluster_id in range(CFG['n_clusters']):
    acc_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
    acc_eq['round'] = 10
    acc_eq['phase'] = 'compromised'
    test_results['transient_compromise_per_cluster_equal'][cluster_id].append(acc_eq)
    
    acc_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
    acc_dir['round'] = 10
    acc_dir['phase'] = 'compromised'
    test_results['transient_compromise_per_cluster_dirichlet'][cluster_id].append(acc_dir)

print(f"[Round 10] Compromised - Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRound 11: Compromise detected, D&R-E phase begins")
print("Rounds 11-17: D&R-E Phase (7 rounds) - CH0 offline, cluster excluded")

for round_num in range(11, 18):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'dre'
    test_results['transient_compromise'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_eq['round'] = round_num
        acc_eq['phase'] = 'dre'
        test_results['transient_compromise_per_cluster_equal'][cluster_id].append(acc_eq)
        
        acc_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_dir['round'] = round_num
        acc_dir['phase'] = 'dre'
        test_results['transient_compromise_per_cluster_dirichlet'][cluster_id].append(acc_dir)
    
    print(f"[Round {round_num}] D&R-E - Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRounds 18-20: Continuity Phase (3 rounds) - Gradual re-entry")
continuity_rates = {18: 0.30, 19: 0.70, 20: 1.00}

for round_num in range(18, 21):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'continuity'
    accuracies['participation_rate'] = continuity_rates[round_num]
    test_results['transient_compromise'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_eq['round'] = round_num
        acc_eq['phase'] = 'continuity'
        acc_eq['participation_rate'] = continuity_rates[round_num]
        test_results['transient_compromise_per_cluster_equal'][cluster_id].append(acc_eq)
        
        acc_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_dir['round'] = round_num
        acc_dir['phase'] = 'continuity'
        acc_dir['participation_rate'] = continuity_rates[round_num]
        test_results['transient_compromise_per_cluster_dirichlet'][cluster_id].append(acc_dir)
    
    print(f"[Round {round_num}] Continuity ({int(continuity_rates[round_num]*100)}%) - "
          f"Traffic: {accuracies['traffic_accuracy']:.4f}")

print("\nRounds 21-30: Re-stabilization Phase")
for round_num in range(21, 31):
    model_path = f'trained_models/hierarchical_equal/model_round_{round_num}.pkl'
    
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    # Global + per-cluster
    accuracies = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='global')
    accuracies['round'] = round_num
    accuracies['phase'] = 'stabilization'
    test_results['transient_compromise'].append(accuracies)
    
    for cluster_id in range(CFG['n_clusters']):
        acc_eq = evaluate_model_on_test(saved['weights'], test_data_equal, model_type='cluster', cluster_id=cluster_id)
        acc_eq['round'] = round_num
        acc_eq['phase'] = 'stabilization'
        test_results['transient_compromise_per_cluster_equal'][cluster_id].append(acc_eq)
        
        acc_dir = evaluate_model_on_test(saved['weights'], test_data_dirichlet, model_type='cluster', cluster_id=cluster_id)
        acc_dir['round'] = round_num
        acc_dir['phase'] = 'stabilization'
        test_results['transient_compromise_per_cluster_dirichlet'][cluster_id].append(acc_dir)
    
    print(f"[Round {round_num}] Stabilization - Traffic: {accuracies['traffic_accuracy']:.4f}")

print(f"\n✅ Transient Compromise complete: {len(test_results['transient_compromise'])} rounds")

print("\n" + "="*80)
print("✅ ALL TESTING COMPLETE!")
print("="*80)
print(f"\nAll test results stored in 'test_results' dictionary:")
print(f"  Single cluster: {len(test_results['single_cluster'])} rounds")
print(f"  Hierarchical equal: {len(test_results['hierarchical_equal'])} rounds")
print(f"  Hierarchical dirichlet: {len(test_results['hierarchical_dirichlet_global'])} rounds")
print(f"  Per-cluster: {len(test_results['hierarchical_dirichlet_per_cluster'][0])} rounds each")
print(f"  Compromise after convergence: {len(test_results['compromise_after_convergence'])} rounds")
print(f"  Transient compromise: {len(test_results['transient_compromise'])} rounds")
print(f"\nReady for visualization!")


## 32. Communication Cost & Convergence Summary

In [None]:
print("COMMUNICATION COST & CONVERGENCE ANALYSIS")

# - HierarchicalMTLFedAvgEnhanced instead of HierarchicalMTLFedAvg
# - CompromisedHierarchicalStrategyEnhanced instead of CompromisedHierarchicalStrategy

print("HOW TO ENABLE ENHANCED TRACKING")
print("""
To get communication cost and convergence tracking, modify your training calls:
1. For normal training (like the baseline test):
 Replace: strategy = HierarchicalMTLFedAvg(...)
 With: strategy = HierarchicalMTLFedAvgEnhanced(...)

2. For compromise tests:
 Replace: strategy = CompromisedHierarchicalStrategy(...)
 With: strategy = CompromisedHierarchicalStrategyEnhanced(...)

3. After training completes, get the summary:
 comm_summary = strategy.get_comm_summary()
 print(f"Model Size: {comm_summary['model_size_formatted']}")
 print(f"Total Communication: {comm_summary['total_cost_formatted']}")
 print(f"Convergence Round: {comm_summary['convergence_round']}")
""")

print("EXPECTED METRICS (Based on Study)")
print("""
From the study slides:
- Model Size: ~278.1 KB
- Formula: W = 2T(N·ω)
 where T = rounds, N = clients per round (600 with 100% participation), ω = model size
Example calculation for 100 rounds:
 W = 2 × 100 × 600 × 278.1 KB
 = 2 × 100 × 600 × 278,100 bytes
 = 33,372,000,000 bytes
 ≈ 33.37 GB

Convergence (from slides):
- Baseline: Converges around round 90
- Variance: < 1% over 10-round window
""")

print("STUDY PARAMETERS TO UPDATE")
print("""
To match the study exactly, update the following in the test cells:
Test 2 - CH Compromise After Convergence:
 Current: compromise_start_round=101, rounds=125, compromised_ch=1
 Update: compromise_start_round=111, rounds=120, compromised_ch=0
 
Test 3 - Transient CH Compromise:
 Current: compromise_start_round=50, rounds=125, compromised_ch=0
 Update: compromise_start_round=11, rounds=30, compromised_ch=0
 
Recovery Phases (from study slides):
 - Detection & Re-Election: 7 rounds (e.g., 111-117 or 11-17)
 - Continuity (Inter-Cluster Sync): 3 rounds (e.g., 118-120 or 18-20)
 - Stabilization: Gradual (30%, 70%, 100% participation)
""")

## 33. Enhanced Visualization with Convergence & Communication Metrics

In [None]:
import matplotlib.pyplot as plt

def plot_with_convergence(
    test_accuracies,
    title,
    convergence_round=None,
    compromise_round=None,
    compromise_end_round=None,
    subplot_ax=None
):
    """
    Plot accuracy with convergence and compromise markers.
    
    Args:
        test_accuracies: List of dicts with task accuracies + round index
        title: Plot title
        convergence_round: Round where convergence was detected
        compromise_round: Round where compromise started
        compromise_end_round: Round where compromise ended
        subplot_ax: Optional matplotlib axis for subplotting
    """

    # Safely extract fields
    rounds = [entry.get("round") for entry in test_accuracies]
    traffic_acc = [entry.get("traffic_accuracy", 0) for entry in test_accuracies]
    duration_acc = [entry.get("duration_accuracy", 0) for entry in test_accuracies]
    bandwidth_acc = [entry.get("bandwidth_accuracy", 0) for entry in test_accuracies]

    # Create figure/axis if needed
    if subplot_ax is None:
        plt.figure(figsize=(14, 8))
        ax = plt.gca()
    else:
        ax = subplot_ax

    # Plot accuracy curves
    ax.plot(rounds, traffic_acc,   color='green',  label='Traffic',   linewidth=2, marker='o', markersize=3)
    ax.plot(rounds, duration_acc,  color='blue',   label='Duration',  linewidth=2, marker='s', markersize=3)
    ax.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth', linewidth=2, marker='^', markersize=3)

    # Convergence marker
    if convergence_round is not None and convergence_round in rounds:
        idx = rounds.index(convergence_round)
        avg_acc = (traffic_acc[idx] + duration_acc[idx] + bandwidth_acc[idx]) / 3

        ax.axvline(
            x=convergence_round,
            color='green',
            linestyle=':',
            linewidth=2,
            label=f'Convergence (Round {convergence_round})',
            alpha=0.7
        )

        ax.scatter(
            [convergence_round],
            [avg_acc],
            color='green',
            s=200,
            marker='*',
            zorder=5,
            edgecolors='black'
        )

    # Compromise indicator
    if compromise_round is not None:
        ax.axvline(
            x=compromise_round,
            color='red',
            linestyle='--',
            linewidth=2.5,
            label=f'CH Compromised (Round {compromise_round})',
            alpha=0.8
        )

        if compromise_end_round is not None:
            ax.fill_between(
                [compromise_round, compromise_end_round],
                0, 1,
                alpha=0.15,
                color='red',
                label='Compromise Period'
            )

    # Labels & grid
    ax.set_xlabel("Rounds", fontsize=12)
    ax.set_ylabel("Test Accuracy", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_ylim(0, 1.05)

    if subplot_ax is None:
        plt.tight_layout()
        plt.show()


## ENHANCED TRACKING IMPLEMENTATION

### What Was Added:

**1. Communication Cost Tracking** 
- Formula: `W = 2T(N·ω)` from the study
 - W = total communication cost (bytes)
 - T = number of rounds
 - N = participating clients per round
 - ω = model size in bytes
 - 2 = bidirectional (upload + download)
- Real-time calculation per round
- Cumulative cost tracking
- Human-readable formatting (KB, MB, GB)

**2. Convergence Detection** 
- Automatic detection when accuracy stabilizes
- Window: 10 rounds
- Threshold: < 1% variance
- Reports convergence round
- Tracks time to convergence

**3. Enhanced Strategy Classes**
- `HierarchicalMTLFedAvgEnhanced`: Base strategy with tracking
- `CompromisedHierarchicalStrategyEnhanced`: Compromise strategy with tracking
- Both inherit from original strategies and add:
 - Communication cost calculation
 - Convergence detection
 - Summary generation via `get_comm_summary()`

**4. Visualization Enhancements**
- `plot_with_convergence()`: Shows convergence markers
- Green star at convergence point
- Red shading for compromise periods
- Vertical lines for key events

**5. Study Alignment**
- Updated compromise rounds to match study:
 - **Test 2**: Round 111 (after convergence)
 - **Test 3**: Round 11 (transient)
- CH0 compromise (matches study diagrams)
- Recovery phases documented:
 - Detection & Re-Election: 7 rounds
 - Continuity: 3 rounds
 - Stabilization: Gradual participation

### How to Use:

```python
# Example with baseline training
strategy = HierarchicalMTLFedAvgEnhanced(...) # Use enhanced version
history = fl.simulation.start_simulation(...)

# After training
summary = strategy.get_comm_summary()
print(f"Model Size: {summary['model_size_formatted']}")
print(f"Total Communication: {summary['total_cost_formatted']}")
print(f"Convergence Round: {summary['convergence_round']}")

# Visualize with convergence markers
plot_with_convergence(
 test_accuracies=strategy.test_accuracies,
 title="Baseline Training",
 convergence_round=summary['convergence_round']
)
```

### Expected Results (from study):
- **Model Size**: ~278.1 KB
- **100 rounds communication**: ~33.37 GB
- **Convergence**: Around round 90
- **Baseline performance**: 
 - Traffic: ~70%
 - Duration: ~60%
 - Bandwidth: ~95%

# QUICK REFERENCE: Implementation vs Study

## Comparison Table

| Feature | Study (Slides) | Our Implementation | Status |
|---------|---------------|-------------------|---------|
| **Architecture** | 3 clusters, CH1 global | 3 clusters, CH1 global | Match |
| **Total Clients** | 600 (200 per cluster) | 600 (200 per cluster) | Match |
| **Data Splits** | Dirichlet & Equal | Dirichlet & Equal | Match |
| **Model Size** | 278.1 KB | Auto-calculated | Ready |
| **Communication Formula** | W = 2T(N·ω) | W = 2T(N·ω) | Implemented |
| **Convergence Detection** | Round ~90, variance < 1% | 10-round window, < 1% | Implemented |
| **Baseline Rounds** | 100 | 100 | Match |
| **Convergence Compromise** | Round 111, CH0 | Need to update to 111, CH0 | Update needed |
| **Transient Compromise** | Round 11, CH0 | Need to update to 11, CH0 | Update needed |
| **Recovery Phases** | D&R-E (7) + Continuity (3) | Documented | Implementation optional |

## Test Parameters to Update

### Test 1: Baseline 
- Already correct: 100 rounds, no compromise

### Test 2: CH Compromise After Convergence 
**Current values:**
```python
rounds=125
compromise_start_round=101
compromised_ch=1
```

**Should be (to match study):**
```python
rounds=120 # 110 normal + 10 with compromise
compromise_start_round=111 # Compromise at round 111
compromised_ch=0 # CH0 compromised (not CH1)
```

### Test 3: Transient CH Compromise 
**Current values:**
```python
rounds=125
compromise_start_round=50
compromised_ch=0
```

**Should be (to match study):**
```python
rounds=30 # Shorter duration for transient test
compromise_start_round=11 # Early compromise at round 11
compromised_ch=0 # Keep CH0
```

## Communication Cost Calculation

**Formula:** `W = 2 × T × N × ω`

**Example for 100 rounds with 600 clients:**
- Model size (ω): ~278.1 KB = 278,100 bytes
- Rounds (T): 100
- Clients per round (N): 600
- Bidirectional factor: 2

**Calculation:**
```
W = 2 × 100 × 600 × 278,100
 = 33,372,000,000 bytes
 = 33.37 GB
```

**Per round:**
```
W_round = 2 × 600 × 278,100
 = 333,720,000 bytes
 = 333.72 MB
```

## 31. Visualization: All Tests Comparison

## Per-Cluster Visualizations: Normal Testing (Equal & Dirichlet Splits)


In [None]:
# Per-Cluster Normal Testing: Equal and Dirichlet Splits (100 rounds)
import matplotlib.pyplot as plt

# Create figure with 2 rows x 3 columns
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
fig.suptitle('Per-Cluster Performance: Normal Testing (100 Rounds)', fontsize=16, fontweight='bold')

# Row 1: Equal Split
for cluster_id in range(3):
    ax = axes[0, cluster_id]
    data = test_results['hierarchical_dirichlet_per_cluster'][cluster_id][:100]
    
    traffic = [item['traffic_accuracy'] for item in data]
    duration = [item['duration_accuracy'] for item in data]
    bandwidth = [item['bandwidth_accuracy'] for item in data]
    rounds = [item['round'] for item in data]
    
    ax.plot(rounds, bandwidth, 'orange', label='Bandwidth', linewidth=2)
    ax.plot(rounds, duration, 'cyan', label='Duration', linewidth=2)
    ax.plot(rounds, traffic, 'teal', label='Traffic', linewidth=2)
    
    ax.set_xlabel('Rounds', fontsize=11)
    ax.set_ylabel('Training Accuracy', fontsize=11)
    ax.set_title(f'Cluster {cluster_id} - Equal Split', fontsize=12, fontweight='bold')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 1.1)

# Row 2: Dirichlet Split  
for cluster_id in range(3):
    ax = axes[1, cluster_id]
    data = test_results['hierarchical_dirichlet_per_cluster'][cluster_id][:100]
    
    traffic = [item['traffic_accuracy'] for item in data]
    duration = [item['duration_accuracy'] for item in data]
    bandwidth = [item['bandwidth_accuracy'] for item in data]
    rounds = [item['round'] for item in data]
    
    ax.plot(rounds, bandwidth, 'orange', label='Bandwidth', linewidth=2)
    ax.plot(rounds, duration, 'cyan', label='Duration', linewidth=2)
    ax.plot(rounds, traffic, 'teal', label='Traffic', linewidth=2)
    
    ax.set_xlabel('Rounds', fontsize=11)
    ax.set_ylabel('Training Accuracy', fontsize=11)
    ax.set_title(f'Cluster {cluster_id} - Dirichlet Split', fontsize=12, fontweight='bold')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 1.1)

plt.tight_layout()
plt.show()

print("✅ Per-cluster normal testing visualizations complete")


## Per-Cluster: CH Compromise After Convergence (Equal & Dirichlet Splits)


In [None]:
# Per-Cluster CH Compromise After Convergence (125 rounds)
import matplotlib.pyplot as plt

# Create figure with 2 rows x 3 columns
fig, axes = plt.subplots(2, 3, figsize=(22, 12))
fig.suptitle('Per-Cluster: CH Compromise After Convergence (125 Rounds)', fontsize=16, fontweight='bold')

# Row 1: Equal Split
for cluster_id in range(3):
    ax = axes[0, cluster_id]
    data = test_results['compromise_after_convergence_per_cluster_equal'][cluster_id]
    
    traffic = [item['traffic_accuracy'] for item in data]
    duration = [item['duration_accuracy'] for item in data]
    bandwidth = [item['bandwidth_accuracy'] for item in data]
    rounds = [item['round'] for item in data]
    
    ax.plot(rounds, bandwidth, 'orange', label='Bandwidth', linewidth=2)
    ax.plot(rounds, duration, 'cyan', label='Duration', linewidth=2)
    ax.plot(rounds, traffic, 'teal', label='Traffic', linewidth=2)
    
    # Phase markers
    ax.axvline(x=90, color='gray', linestyle=':', linewidth=1, alpha=0.5)
    ax.text(90, 0.95, 'round 90', rotation=90, va='top', fontsize=8, alpha=0.7)
    
    ax.axvspan(111, 118, alpha=0.15, color='pink', label='D&R-E\n(111-117)')
    ax.axvspan(118, 121, alpha=0.15, color='yellow', label='Continuity\n(118-120)')
    
    ax.set_xlabel('Rounds', fontsize=10)
    ax.set_ylabel('Test Accuracy', fontsize=10)
    ax.set_title(f'Cluster {cluster_id} - Equal Split', fontsize=11, fontweight='bold')
    ax.legend(loc='lower right', fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 125)

# Row 2: Dirichlet Split
for cluster_id in range(3):
    ax = axes[1, cluster_id]
    data = test_results['compromise_after_convergence_per_cluster_dirichlet'][cluster_id]
    
    traffic = [item['traffic_accuracy'] for item in data]
    duration = [item['duration_accuracy'] for item in data]
    bandwidth = [item['bandwidth_accuracy'] for item in data]
    rounds = [item['round'] for item in data]
    
    ax.plot(rounds, bandwidth, 'orange', label='Bandwidth', linewidth=2)
    ax.plot(rounds, duration, 'cyan', label='Duration', linewidth=2)
    ax.plot(rounds, traffic, 'teal', label='Traffic', linewidth=2)
    
    # Phase markers
    ax.axvline(x=90, color='gray', linestyle=':', linewidth=1, alpha=0.5)
    ax.text(90, 0.95, 'round 90', rotation=90, va='top', fontsize=8, alpha=0.7)
    
    ax.axvspan(111, 118, alpha=0.15, color='pink', label='D&R-E\n(111-117)')
    ax.axvspan(118, 121, alpha=0.15, color='yellow', label='Continuity\n(118-120)')
    
    ax.set_xlabel('Rounds', fontsize=10)
    ax.set_ylabel('Test Accuracy', fontsize=10)
    ax.set_title(f'Cluster {cluster_id} - Dirichlet Split', fontsize=11, fontweight='bold')
    ax.legend(loc='lower right', fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 125)

plt.tight_layout()
plt.show()

print("✅ Per-cluster CH compromise after convergence visualizations complete")


## Per-Cluster: Transient CH Compromise (Equal & Dirichlet Splits)


In [None]:
# Per-Cluster Transient CH Compromise (30 rounds)
import matplotlib.pyplot as plt

# Create figure with 2 rows x 3 columns
fig, axes = plt.subplots(2, 3, figsize=(22, 12))
fig.suptitle('Per-Cluster: Transient CH Compromise (30 Rounds)', fontsize=16, fontweight='bold')

# Row 1: Equal Split
for cluster_id in range(3):
    ax = axes[0, cluster_id]
    data = test_results['transient_compromise_per_cluster_equal'][cluster_id]
    
    traffic = [item['traffic_accuracy'] for item in data]
    duration = [item['duration_accuracy'] for item in data]
    bandwidth = [item['bandwidth_accuracy'] for item in data]
    rounds = [item['round'] for item in data]
    
    ax.plot(rounds, bandwidth, 'orange', label='Bandwidth', linewidth=2)
    ax.plot(rounds, duration, 'cyan', label='Duration', linewidth=2)
    ax.plot(rounds, traffic, 'teal', label='Traffic', linewidth=2)
    
    # Phase markers
    ax.axvspan(10, 11, alpha=0.2, color='gray', label='Detection\n(round 10)')
    ax.axvspan(11, 18, alpha=0.15, color='pink', label='D&R-E\n(11-17)')
    ax.axvspan(18, 21, alpha=0.15, color='yellow', label='Continuity\n(18-20)')
    ax.axvspan(21, 30, alpha=0.10, color='lightgreen', label='Stabilization\n(21-30)')
    
    ax.set_xlabel('Global Rounds', fontsize=10)
    ax.set_ylabel('Test Accuracy', fontsize=10)
    ax.set_title(f'Cluster {cluster_id} - Equal Split', fontsize=11, fontweight='bold')
    ax.legend(loc='lower right', fontsize=7)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 30)

# Row 2: Dirichlet Split
for cluster_id in range(3):
    ax = axes[1, cluster_id]
    data = test_results['transient_compromise_per_cluster_dirichlet'][cluster_id]
    
    traffic = [item['traffic_accuracy'] for item in data]
    duration = [item['duration_accuracy'] for item in data]
    bandwidth = [item['bandwidth_accuracy'] for item in data]
    rounds = [item['round'] for item in data]
    
    ax.plot(rounds, bandwidth, 'orange', label='Bandwidth', linewidth=2)
    ax.plot(rounds, duration, 'cyan', label='Duration', linewidth=2)
    ax.plot(rounds, traffic, 'teal', label='Traffic', linewidth=2)
    
    # Phase markers  
    ax.axvspan(10, 11, alpha=0.2, color='gray', label='Detection\n(round 10)')
    ax.axvspan(11, 18, alpha=0.15, color='pink', label='D&R-E\n(11-17)')
    ax.axvspan(18, 21, alpha=0.15, color='yellow', label='Continuity\n(18-20)')
    ax.axvspan(21, 30, alpha=0.10, color='lightgreen', label='Stabilization\n(21-30)')
    
    ax.set_xlabel('Global Rounds', fontsize=10)
    ax.set_ylabel('Test Accuracy', fontsize=10)
    ax.set_title(f'Cluster {cluster_id} - Dirichlet Split', fontsize=11, fontweight='bold')
    ax.legend(loc='lower right', fontsize=7)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 30)

plt.tight_layout()
plt.show()

print("✅ Per-cluster transient compromise visualizations complete")


 the graph

In [None]:
print("SUMMARY STATISTICS - ALL TESTS")

print("\n TEST 1: BASELINE (100 Rounds)")
print(f"Final Accuracies (Round 100):")
print(f" Traffic: {baseline_traffic[99]:.4f} ({baseline_traffic[99]*100:.2f}%)")
print(f" Duration: {baseline_duration[99]:.4f} ({baseline_duration[99]*100:.2f}%)")
print(f" Bandwidth: {baseline_bandwidth[99]:.4f} ({baseline_bandwidth[99]*100:.2f}%)")
avg_baseline = (baseline_traffic[99] + baseline_duration[99] + baseline_bandwidth[99]) / 3
print(f" Average: {avg_baseline:.4f} ({avg_baseline*100:.2f}%)")

print("\n TEST 2: CH COMPROMISE AFTER CONVERGENCE (CH0 at Round 111)")
print(f"Before Compromise (Round 110):")
print(f" Traffic: {compromise_after_traffic[109]:.4f}")
print(f" Duration: {compromise_after_duration[109]:.4f}")
print(f" Bandwidth: {compromise_after_bandwidth[109]:.4f}")
avg_before = (compromise_after_traffic[109] + compromise_after_duration[109] + compromise_after_bandwidth[109]) / 3
print(f" Average: {avg_before:.4f}")

print(f"\nAfter Compromise (Round 120):")
print(f" Traffic: {compromise_after_traffic[-1]:.4f}")
print(f" Duration: {compromise_after_duration[-1]:.4f}")
print(f" Bandwidth: {compromise_after_bandwidth[-1]:.4f}")
avg_after = (compromise_after_traffic[-1] + compromise_after_duration[-1] + compromise_after_bandwidth[-1]) / 3
print(f" Average: {avg_after:.4f}")

print(f"\nImpact of Compromise:")
print(f" Traffic: {(compromise_after_traffic[109] - compromise_after_traffic[-1])*100:.2f}% drop")
print(f" Duration: {(compromise_after_duration[109] - compromise_after_duration[-1])*100:.2f}% drop")
print(f" Bandwidth: {(compromise_after_bandwidth[109] - compromise_after_bandwidth[-1])*100:.2f}% drop")
print(f" Average: {(avg_before - avg_after)*100:.2f}% drop")

print("\n TEST 3: TRANSIENT CH COMPROMISE (CH0 from Round 11)")
print(f"Before Compromise (Round 10):")
print(f" Traffic: {transient_traffic[9]:.4f}")
print(f" Duration: {transient_duration[9]:.4f}")
print(f" Bandwidth: {transient_bandwidth[9]:.4f}")
avg_trans_before = (transient_traffic[9] + transient_duration[9] + transient_bandwidth[9]) / 3
print(f" Average: {avg_trans_before:.4f}")

print(f"\nAfter Compromise (Round 30):")
print(f" Traffic: {transient_traffic[-1]:.4f}")
print(f" Duration: {transient_duration[-1]:.4f}")
print(f" Bandwidth: {transient_bandwidth[-1]:.4f}")
avg_trans_after = (transient_traffic[-1] + transient_duration[-1] + transient_bandwidth[-1]) / 3
print(f" Average: {avg_trans_after:.4f}")

print(f"\nImpact of Compromise:")
print(f" Traffic: {(transient_traffic[9] - transient_traffic[-1])*100:.2f}% difference")
print(f" Duration: {(transient_duration[9] - transient_duration[-1])*100:.2f}% difference")
print(f" Bandwidth: {(transient_bandwidth[9] - transient_bandwidth[-1])*100:.2f}% difference")
print(f" Average: {(avg_trans_before - avg_trans_after)*100:.2f}% difference")

print("KEY INSIGHTS")
print("1. Baseline converges around round 100 with stable accuracy")
print("2. CH compromise after convergence shows immediate performance degradation")
print("3. Transient compromise during training affects learning trajectory")
print("4. Different CHs (CH0 vs CH1) may have different impact levels")
print("5. The hierarchical architecture shows resilience/vulnerability patterns")

## All CH Compromisation Tests Complete!

### Test Summary (Updated to Match Study):
- **Test 1**: Baseline (100 rounds) - Normal convergence without compromise
- **Test 2**: CH Compromise After Convergence (120 rounds) - **CH0 compromised at round 111** (matches study slides 7-9)
- **Test 3**: Transient CH Compromise (30 rounds) - **CH0 compromised at round 11** (matches study slide 12)

### Visualization Layout:
**Row 1 (First 100 Rounds):**
- Graph 1: Traffic Classification (Baseline vs Pre-Compromise)
- Graph 2: Duration Classification (Baseline vs Pre-Compromise)

**Row 2 (Full 120 Rounds - After Convergence):**
- Graph 3: Traffic with **CH0 compromise at round 111**
- Graph 4: Duration & Bandwidth with **CH0 compromise at round 111**

**Row 3 (Full 30 Rounds - Transient):**
- Graph 5: Traffic with **CH0 compromise at round 11**
- Graph 6: Duration & Bandwidth with **CH0 compromise at round 11**

### Architecture Used:
- **3 Clusters**: 600 clients total (200 per cluster)
- **Cluster Heads**: CH0, CH1, CH2
- **Global Aggregator**: CH1
- **Two-Tier Hierarchy**:
 - Tier 1: Members → CH (local aggregation)
 - Tier 2: CH0, CH2 → CH1 (global) → CH0, CH2 → Members

### Compromise Method:
- **Type**: Label flipping (parameter sign inversion)
- **Impact**: Malicious CH sends poisoned model parameters
- **Detection**: Observable through accuracy degradation

### Key Findings (Based on Study):
1. Baseline achieves stable convergence around round 90 (detected automatically)
2. Post-convergence compromise (round 111) shows immediate performance degradation
3. Transient compromise (round 11) affects learning trajectory from early stages
4. **CH0 compromise** studied (local cluster head impact)
5. System shows vulnerability to CH-level attacks, requiring recovery mechanisms
6. Study shows recovery phases: Detection & Re-Election (7 rounds) + Continuity (3 rounds)

### Next Experiments:
- Test different compromise types (random_noise, model_poison)
- Compromise different CHs (CH2, multiple CHs simultaneously)
- Implement defense mechanisms (anomaly detection, secure aggregation)
- Compare impact under equal vs Dirichlet data splits

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Check if test results are available
if 'test_results' not in locals():
    print("⚠️ Test results not available. Please run the testing phase first.")
else:
    print("="*80)
    print("PLOTTING PER-TASK TRAINING ACCURACY GRAPHS")
    print("="*80)
    
    # Define experiments to plot
    experiments = {
        'single_cluster': 'Single Cluster',
        'hierarchical_equal': 'Hierarchical (Equal Split)'
    }
    
    # Create figure with 3 subplots (one per task)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    tasks = [
        ('traffic_accuracy', 'Traffic Classification', 'green'),
        ('duration_accuracy', 'Flow Duration Classification', 'blue'),
        ('bandwidth_accuracy', 'Bandwidth Classification', 'orange')
    ]
    
    for idx, (task_metric, task_title, color) in enumerate(tasks):
        ax = axes[idx]
        
        # Plot each experiment
        for exp_key, exp_label in experiments.items():
            if exp_key in test_results and test_results[exp_key]:
                data = test_results[exp_key]
                rounds = [item['round'] for item in data]
                accuracies = [item[task_metric] for item in data]
                
                linestyle = '-' if exp_key == 'single_cluster' else '--'
                marker = 'o' if exp_key == 'single_cluster' else 's'
                
                ax.plot(rounds, accuracies, color=color, label=exp_label,
                       linewidth=2, marker=marker, markersize=4, 
                       linestyle=linestyle, markevery=10)
        
        # Format subplot
        ax.set_xlabel('Rounds', fontsize=11)
        ax.set_ylabel('Accuracy', fontsize=11)
        ax.set_title(task_title, fontsize=12, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)
        ax.grid(True, alpha=0.3, linestyle='--')
        
        # Set y-axis limits with padding
        if exp_key in test_results and test_results[exp_key]:
            all_acc = [item[task_metric] for item in test_results['single_cluster']]
            if 'hierarchical_equal' in test_results:
                all_acc += [item[task_metric] for item in test_results['hierarchical_equal']]
            y_min = max(0.0, min(all_acc) - 0.05)
            y_max = min(1.0, max(all_acc) + 0.05)
            ax.set_ylim(y_min, y_max)
    
    plt.suptitle('Federated Multi-Task Learning - Per-Task Training Accuracy', 
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics for each task
    print("\n" + "="*80)
    print("PER-TASK SUMMARY STATISTICS")
    print("="*80)
    
    for task_metric, task_title, _ in tasks:
        print(f"\n{task_title}:")
        print("-" * 80)
        
        for exp_key, exp_label in experiments.items():
            if exp_key in test_results and test_results[exp_key]:
                data = test_results[exp_key]
                accuracies = [item[task_metric] for item in data]
                rounds = [item['round'] for item in data]
                
                print(f"\n  {exp_label}:")
                print(f"    First round:  {accuracies[0]:.4f} ({accuracies[0]*100:.2f}%)")
                print(f"    Last round:   {accuracies[-1]:.4f} ({accuracies[-1]*100:.2f}%)")
                print(f"    Best:         {max(accuracies):.4f} ({max(accuracies)*100:.2f}%) at Round {rounds[np.argmax(accuracies)]}")
                print(f"    Improvement:  +{accuracies[-1] - accuracies[0]:.4f}")
    
    print("\n" + "="*80)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Extract data from single cluster training history
if 'history_single' in locals():
    # Debug: Print available attributes
    print("Available history_single attributes:")
    print([attr for attr in dir(history_single) if not attr.startswith('_')])
    print()
    
    # Check both metrics_distributed_fit (training) and metrics_distributed (evaluation)
    metrics_fit = getattr(history_single, 'metrics_distributed_fit', {})
    metrics_eval = getattr(history_single, 'metrics_distributed', {})
    
    print("Available metrics in metrics_distributed_fit:", list(metrics_fit.keys()) if metrics_fit else "None")
    print("Available metrics in metrics_distributed:", list(metrics_eval.keys()) if metrics_eval else "None")
    print()
    
    # Try to get rounds from either source
    rounds = []
    traffic_acc = []
    duration_acc = []
    bandwidth_acc = []
    
    # First, try to get per-task accuracies from evaluation metrics
    if metrics_eval:
        if 'traffic_accuracy' in metrics_eval:
            rounds = [r for r, _ in metrics_eval['traffic_accuracy']]
            traffic_acc = [float(v) for _, v in metrics_eval['traffic_accuracy']]
            duration_acc = [float(v) for _, v in metrics_eval['duration_accuracy']]
            bandwidth_acc = [float(v) for _, v in metrics_eval['bandwidth_accuracy']]
        elif 'accuracy' in metrics_eval:
            rounds = [r for r, _ in metrics_eval['accuracy']]
    
    # If not found, try training metrics
    if not rounds and metrics_fit:
        if 'traffic_accuracy' in metrics_fit:
            rounds = [r for r, _ in metrics_fit['traffic_accuracy']]
            traffic_acc = [float(v) for _, v in metrics_fit['traffic_accuracy']]
            duration_acc = [float(v) for _, v in metrics_fit['duration_accuracy']]
            bandwidth_acc = [float(v) for _, v in metrics_fit['bandwidth_accuracy']]
        elif 'accuracy' in metrics_fit:
            rounds = [r for r, _ in metrics_fit['accuracy']]
    
    # Fallback: Use test_results if available
    if not rounds and 'test_results' in locals() and 'single_cluster' in test_results:
        print("Using test_results for plotting...")
        data = test_results['single_cluster']
        rounds = [item['round'] for item in data]
        traffic_acc = [item['traffic_accuracy'] for item in data]
        duration_acc = [item['duration_accuracy'] for item in data]
        bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    if not rounds:
        print("⚠️ No training metrics found in history_single or test_results")
        print("Please ensure training has been completed.")
    else:
        print(f"✅ Single cluster training: {len(rounds)} rounds")
        
        # Create styled plot matching your reference
        plt.figure(figsize=(12, 8))
        
        # Plot per-task accuracies if available
        if traffic_acc and duration_acc and bandwidth_acc:
            # Calculate y-axis bounds
            all_acc = traffic_acc + duration_acc + bandwidth_acc
            y_min = max(0.0, min(all_acc) - 0.05)
            y_max = min(1.0, max(all_acc) + 0.05)
            
            plt.plot(rounds, traffic_acc, color='green', label='Traffic Classification', 
                    linewidth=2, marker='o', markersize=4)
            plt.plot(rounds, duration_acc, color='blue', label='Flow Duration Classification', 
                    linewidth=2, marker='s', markersize=4)
            plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
                    linewidth=2, marker='^', markersize=4)
            
            plt.ylim(y_min, y_max)
        else:
            # Fallback: plot overall accuracy if available
            if 'accuracy' in metrics_eval or 'accuracy' in metrics_fit:
                acc_metrics = metrics_eval.get('accuracy', metrics_fit.get('accuracy', []))
                if acc_metrics:
                    acc_values = [float(v) for _, v in acc_metrics]
                    plt.plot(rounds, acc_values, color='blue', label='Overall Accuracy', 
                            linewidth=2, marker='o', markersize=4)
                    plt.ylim(0, 1.05)
            else:
                print("⚠️ No accuracy metrics available to plot")
        
        plt.xlabel('Rounds', fontsize=12)
        plt.ylabel('Accuracy', fontsize=12)
        plt.title('Single Cluster Training - Federated Multi-Task Learning', fontsize=14, fontweight='bold')
        plt.legend(loc='lower right', fontsize=11)
        plt.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()
        plt.show()
        
        print("="*80)
        print("SINGLE CLUSTER TRAINING SUMMARY")
        print("="*80)
        print(f"Total Rounds: {len(rounds)}")
        if traffic_acc and duration_acc and bandwidth_acc:
            print(f"\nFinal Accuracies:")
            print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
            print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
            print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
        print("="*80)
else:
    print("⚠️ Training history not available yet - run training cells first")

In [None]:
#Test Graphs

In [None]:
import matplotlib.pyplot as plt
import numpy as np

if 'test_results' in locals() and 'single_cluster' in test_results:
    data = test_results['single_cluster']
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    # Calculate y-axis bounds
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(12, 8))
    plt.plot(rounds, traffic_acc, color='green', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=4)
    plt.plot(rounds, duration_acc, color='blue', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=4)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=4)
    
    plt.xlabel('Rounds', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.title('Single Cluster Testing - Baseline Performance', fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=11)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print("SINGLE CLUSTER TEST RESULTS")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nFinal Accuracies (Round {rounds[-1]}):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print(f"\nBest Accuracies:")
    print(f"  Traffic:   {max(traffic_acc):.4f} at Round {rounds[np.argmax(traffic_acc)]}")
    print(f"  Duration:  {max(duration_acc):.4f} at Round {rounds[np.argmax(duration_acc)]}")
    print(f"  Bandwidth: {max(bandwidth_acc):.4f} at Round {rounds[np.argmax(bandwidth_acc)]}")
    print("="*80)
else:
    print("⚠️ Run testing cells first to generate test_results")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

cluster_id = 0
if 'test_results' in locals() and 'hierarchical_dirichlet_per_cluster' in test_results:
    data = test_results['hierarchical_dirichlet_per_cluster'][cluster_id][:100]
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(12, 8))
    plt.plot(rounds, traffic_acc, color='teal', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=4)
    plt.plot(rounds, duration_acc, color='cyan', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=4)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=4)
    
    plt.xlabel('Rounds', fontsize=12)
    plt.ylabel('Training Accuracy', fontsize=12)
    plt.title(f'Cluster {cluster_id} - Equal Split (100 Rounds)', fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=11)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print(f"CLUSTER {cluster_id} - EQUAL SPLIT TEST RESULTS")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nFinal Accuracies (Round {rounds[-1]}):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print("="*80)
else:
    print("⚠️ Run per-cluster testing cells first")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

cluster_id = 2
if 'test_results' in locals() and 'hierarchical_dirichlet_per_cluster' in test_results:
    data = test_results['hierarchical_dirichlet_per_cluster'][cluster_id][:100]
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(12, 8))
    plt.plot(rounds, traffic_acc, color='teal', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=4)
    plt.plot(rounds, duration_acc, color='cyan', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=4)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=4)
    
    plt.xlabel('Rounds', fontsize=12)
    plt.ylabel('Training Accuracy', fontsize=12)
    plt.title(f'Cluster {cluster_id} - Equal Split (100 Rounds)', fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=11)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print(f"CLUSTER {cluster_id} - EQUAL SPLIT TEST RESULTS")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nFinal Accuracies (Round {rounds[-1]}):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print("="*80)
else:
    print("⚠️ Run per-cluster testing cells first")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

cluster_id = 1
if 'test_results' in locals() and 'hierarchical_dirichlet_per_cluster' in test_results:
    data = test_results['hierarchical_dirichlet_per_cluster'][cluster_id][:100]
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(12, 8))
    plt.plot(rounds, traffic_acc, color='teal', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=4)
    plt.plot(rounds, duration_acc, color='cyan', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=4)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=4)
    
    plt.xlabel('Rounds', fontsize=12)
    plt.ylabel('Training Accuracy', fontsize=12)
    plt.title(f'Cluster {cluster_id} - Equal Split (100 Rounds)', fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=11)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print(f"CLUSTER {cluster_id} - EQUAL SPLIT TEST RESULTS")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nFinal Accuracies (Round {rounds[-1]}):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print("="*80)
else:
    print("⚠️ Run per-cluster testing cells first")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

cluster_id = 0
if 'test_results' in locals() and 'compromise_after_convergence_per_cluster_equal' in test_results:
    data = test_results['compromise_after_convergence_per_cluster_equal'][cluster_id]
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(14, 8))
    plt.plot(rounds, traffic_acc, color='teal', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=3)
    plt.plot(rounds, duration_acc, color='cyan', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=3)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=3)
    
    # Phase markers
    plt.axvline(x=90, color='gray', linestyle=':', linewidth=1.5, alpha=0.7, label='Round 90 (Convergence)')
    plt.axvspan(111, 118, alpha=0.15, color='pink', label='D&R-E Phase (111-117)')
    plt.axvspan(118, 121, alpha=0.15, color='yellow', label='Continuity (118-120)')
    plt.axvspan(121, 125, alpha=0.10, color='lightgreen', label='Stabilization (121-125)')
    
    plt.xlabel('Global Rounds', fontsize=12)
    plt.ylabel('Test Accuracy', fontsize=12)
    plt.title(f'CH Compromise After Convergence - Cluster {cluster_id} (Equal Split)', 
             fontsize=14, fontweight='bold')
    plt.legend(loc='best', fontsize=10, ncol=2)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.xlim(0, 125)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print(f"CH COMPROMISE - CLUSTER {cluster_id} (EQUAL SPLIT)")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nBefore Compromise (Round 110):")
    if len(traffic_acc) > 109:
        print(f"  Traffic:   {traffic_acc[109]:.4f} ({traffic_acc[109]*100:.2f}%)")
        print(f"  Duration:  {duration_acc[109]:.4f} ({duration_acc[109]*100:.2f}%)")
        print(f"  Bandwidth: {bandwidth_acc[109]:.4f} ({bandwidth_acc[109]*100:.2f}%)")
    print(f"\nAfter Recovery (Round 125):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print("="*80)
else:
    print("⚠️ Run CH compromise testing cell first")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

cluster_id = 1
if 'test_results' in locals() and 'compromise_after_convergence_per_cluster_equal' in test_results:
    data = test_results['compromise_after_convergence_per_cluster_equal'][cluster_id]
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(14, 8))
    plt.plot(rounds, traffic_acc, color='teal', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=3)
    plt.plot(rounds, duration_acc, color='cyan', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=3)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=3)
    
    # Phase markers
    plt.axvline(x=90, color='gray', linestyle=':', linewidth=1.5, alpha=0.7, label='Round 90 (Convergence)')
    plt.axvspan(111, 118, alpha=0.15, color='pink', label='D&R-E Phase (111-117)')
    plt.axvspan(118, 121, alpha=0.15, color='yellow', label='Continuity (118-120)')
    plt.axvspan(121, 125, alpha=0.10, color='lightgreen', label='Stabilization (121-125)')
    
    plt.xlabel('Global Rounds', fontsize=12)
    plt.ylabel('Test Accuracy', fontsize=12)
    plt.title(f'CH Compromise After Convergence - Cluster {cluster_id} (Equal Split)', 
             fontsize=14, fontweight='bold')
    plt.legend(loc='best', fontsize=10, ncol=2)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.xlim(0, 125)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print(f"CH COMPROMISE - CLUSTER {cluster_id} (EQUAL SPLIT)")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nBefore Compromise (Round 110):")
    if len(traffic_acc) > 109:
        print(f"  Traffic:   {traffic_acc[109]:.4f} ({traffic_acc[109]*100:.2f}%)")
        print(f"  Duration:  {duration_acc[109]:.4f} ({duration_acc[109]*100:.2f}%)")
        print(f"  Bandwidth: {bandwidth_acc[109]:.4f} ({bandwidth_acc[109]*100:.2f}%)")
    print(f"\nAfter Recovery (Round 125):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print("="*80)
else:
    print("⚠️ Run CH compromise testing cell first")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

cluster_id = 2
if 'test_results' in locals() and 'compromise_after_convergence_per_cluster_equal' in test_results:
    data = test_results['compromise_after_convergence_per_cluster_equal'][cluster_id]
    rounds = [item['round'] for item in data]
    traffic_acc = [item['traffic_accuracy'] for item in data]
    duration_acc = [item['duration_accuracy'] for item in data]
    bandwidth_acc = [item['bandwidth_accuracy'] for item in data]
    
    all_acc = traffic_acc + duration_acc + bandwidth_acc
    y_min = max(0.0, min(all_acc) - 0.05)
    y_max = min(1.0, max(all_acc) + 0.05)
    
    plt.figure(figsize=(14, 8))
    plt.plot(rounds, traffic_acc, color='teal', label='Traffic Classification', 
            linewidth=2, marker='o', markersize=3)
    plt.plot(rounds, duration_acc, color='cyan', label='Flow Duration Classification', 
            linewidth=2, marker='s', markersize=3)
    plt.plot(rounds, bandwidth_acc, color='orange', label='Bandwidth Classification', 
            linewidth=2, marker='^', markersize=3)
    
    # Phase markers
    plt.axvline(x=90, color='gray', linestyle=':', linewidth=1.5, alpha=0.7, label='Round 90 (Convergence)')
    plt.axvspan(111, 118, alpha=0.15, color='pink', label='D&R-E Phase (111-117)')
    plt.axvspan(118, 121, alpha=0.15, color='yellow', label='Continuity (118-120)')
    plt.axvspan(121, 125, alpha=0.10, color='lightgreen', label='Stabilization (121-125)')
    
    plt.xlabel('Global Rounds', fontsize=12)
    plt.ylabel('Test Accuracy', fontsize=12)
    plt.title(f'CH Compromise After Convergence - Cluster {cluster_id} (Equal Split)', 
             fontsize=14, fontweight='bold')
    plt.legend(loc='best', fontsize=10, ncol=2)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.ylim(y_min, y_max)
    plt.xlim(0, 125)
    plt.tight_layout()
    plt.show()
    
    print("="*80)
    print(f"CH COMPROMISE - CLUSTER {cluster_id} (EQUAL SPLIT)")
    print("="*80)
    print(f"Total Rounds: {len(rounds)}")
    print(f"\nBefore Compromise (Round 110):")
    if len(traffic_acc) > 109:
        print(f"  Traffic:   {traffic_acc[109]:.4f} ({traffic_acc[109]*100:.2f}%)")
        print(f"  Duration:  {duration_acc[109]:.4f} ({duration_acc[109]*100:.2f}%)")
        print(f"  Bandwidth: {bandwidth_acc[109]:.4f} ({bandwidth_acc[109]*100:.2f}%)")
    print(f"\nAfter Recovery (Round 125):")
    print(f"  Traffic:   {traffic_acc[-1]:.4f} ({traffic_acc[-1]*100:.2f}%)")
    print(f"  Duration:  {duration_acc[-1]:.4f} ({duration_acc[-1]*100:.2f}%)")
    print(f"  Bandwidth: {bandwidth_acc[-1]:.4f} ({bandwidth_acc[-1]*100:.2f}%)")
    print("="*80)
else:
    print("⚠️ Run CH compromise testing cell first")