In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.neural_network import MLPClassifier
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import copy
import warnings
# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# One shared seed for every RNG used below (stdlib random, NumPy, PyTorch).
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Quick environment report.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
print(" Imports OK")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {_device_name}")

In [None]:
# Location of the NSL-KDD files (Kaggle dataset mount point).
DATA_DIR = r"/kaggle/input/nslkdd"

# The 41 NSL-KDD features, followed by the class label and the difficulty score.
column_names = [
    'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
    'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in',
    'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
    'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login',
    'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
    'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate',
    'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
    'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
    'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate',
    'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
    'label', 'difficulty'
]


def _read_nsl_kdd(stem):
    """Read one NSL-KDD split from DATA_DIR.

    Tries ``<stem>.txt`` first and falls back to the extension-less ``<stem>``
    only when the first file does not exist. Any other failure (parse error,
    permission problem, interrupt) is propagated instead of being masked by a
    second read attempt.

    Args:
        stem: File name without extension, e.g. ``"KDDTrain+"``.

    Returns:
        DataFrame with the columns listed in ``column_names``.
    """
    try:
        return pd.read_csv(f"{DATA_DIR}/{stem}.txt", header=None, names=column_names)
    except FileNotFoundError:
        return pd.read_csv(f"{DATA_DIR}/{stem}", header=None, names=column_names)


# Load KDDTrain+ and KDDTest+
train_data = _read_nsl_kdd("KDDTrain+")
test_data = _read_nsl_kdd("KDDTest+")

# The difficulty score is not a feature: drop it from both splits.
train_data = train_data.drop(columns=['difficulty'])
test_data = test_data.drop(columns=['difficulty'])

print(f" Train shape: {train_data.shape}")
print(f" Test shape: {test_data.shape}")
print("\n Premi√®res lignes:")
print(train_data.head())

In [None]:
# Raw NSL-KDD labels grouped by attack family. Any label that is neither
# 'normal' nor listed here is treated as 'U2R' (the fall-through family).
_ATTACK_FAMILIES = {
    'DoS': ['back', 'land', 'neptune', 'pod', 'smurf', 'teardrop', 'apache2', 'udpstorm', 'processtable', 'worm'],
    'Probe': ['ipsweep', 'nmap', 'portsweep', 'satan', 'mscan', 'saint'],
    'R2L': ['ftp_write', 'guess_passwd', 'imap', 'multihop', 'phf', 'spy', 'warezclient', 'warezmaster', 'sendmail', 'named', 'snmpgetattack', 'snmpguess', 'xlock', 'xsnoop', 'httptunnel'],
}

# Flat lookup: raw label -> family name.
_LABEL_TO_FAMILY = {'normal': 'normal'}
for _family, _raw_labels in _ATTACK_FAMILIES.items():
    for _raw_label in _raw_labels:
        _LABEL_TO_FAMILY[_raw_label] = _family


def _binary_target(raw_label):
    """Return 0 for 'normal' traffic and 1 for any other label."""
    return 0 if raw_label == 'normal' else 1


def _attack_family(raw_label):
    """Map a raw label to 'normal', 'DoS', 'Probe', 'R2L' or (by default) 'U2R'."""
    return _LABEL_TO_FAMILY.get(raw_label, 'U2R')


# Binary target for both splits.
for _frame in (train_data, test_data):
    _frame['binary_label'] = _frame['label'].apply(_binary_target)

# Multi-class family, kept on the training split only; used for the Non-IID partition.
train_data['attack_type'] = train_data['label'].apply(_attack_family)

print("\n Distribution des classes (Train):")
print(train_data['binary_label'].value_counts())
# Note: the printed value is the share of attack rows among all training rows.
print(f"Ratio Attack/Normal: {train_data['binary_label'].sum() / len(train_data):.2%}")

print("\n Distribution par type d'attaque:")
print(train_data['attack_type'].value_counts())

In [None]:
# One-hot encode the categorical columns of both splits.
categorical_cols = ['protocol_type', 'service', 'flag']

train_encoded, test_encoded = (
    pd.get_dummies(split_frame, columns=categorical_cols)
    for split_frame in (train_data, test_data)
)

# Give the test frame exactly the training columns: columns missing from the
# test split are added and filled with 0, test-only columns are discarded.
train_encoded, test_encoded = train_encoded.align(
    test_encoded,
    join='left',
    axis=1,
    fill_value=0,
)

print(f" Shape apr√®s encodage: {train_encoded.shape}")

In [None]:
# Columns that carry target information and must not be fed to the models.
cols_to_drop = ['label', 'binary_label', 'attack_type']

# Targets (binary) and the per-row attack family of the training split.
y_train = train_encoded['binary_label'].values
y_test = test_encoded['binary_label'].values
attack_types_train = train_data['attack_type'].values

# Min-max scaling: statistics come from the training split only and are then
# re-used unchanged on the test split.
scaler = MinMaxScaler()
X_train = scaler.fit_transform(train_encoded.drop(columns=cols_to_drop))
X_test = scaler.transform(test_encoded.drop(columns=cols_to_drop))

print(f" X_train: {X_train.shape}, y_train: {y_train.shape}")
print(f" X_test: {X_test.shape}, y_test: {y_test.shape}")
print(f" Features normalis√©es dans [0, 1]")

In [None]:
NUM_CLIENTS = 5

# Row positions of every traffic family in the training set.
normal_idx, dos_idx, probe_idx, r2l_idx, u2r_idx = (
    np.where(attack_types_train == family_name)[0]
    for family_name in ('normal', 'DoS', 'Probe', 'R2L', 'U2R')
)

# In-place shuffle of each family (order of the calls matters for reproducibility).
for _family_idx in (normal_idx, dos_idx, probe_idx, r2l_idx, u2r_idx):
    np.random.shuffle(_family_idx)

# Normal traffic is spread evenly over the clients.
normal_split = np.array_split(normal_idx, NUM_CLIENTS)

# Cut points: the first 70% of each attack family goes to its dedicated client.
dos_split = int(len(dos_idx) * 0.7)
probe_split = int(len(probe_idx) * 0.7)
r2l_split = int(len(r2l_idx) * 0.7)
u2r_split = int(len(u2r_idx) * 0.7)

# Client 5 receives its normal share plus the remaining 30% of every attack family.
remaining = np.concatenate([
    normal_split[4],
    dos_idx[dos_split:],
    probe_idx[probe_split:],
    r2l_idx[r2l_split:],
    u2r_idx[u2r_split:]
])

# Non-IID layout: clients 1-4 each specialise in one attack family.
client_data = {
    0: np.concatenate([normal_split[0], dos_idx[:dos_split]]),      # Normal + 70% DoS
    1: np.concatenate([normal_split[1], probe_idx[:probe_split]]),  # Normal + 70% Probe
    2: np.concatenate([normal_split[2], r2l_idx[:r2l_split]]),      # Normal + 70% R2L
    3: np.concatenate([normal_split[3], u2r_idx[:u2r_split]]),      # Normal + 70% U2R
    4: remaining,                                                   # Normal + leftovers
}

# Shuffle the rows inside each client.
for client_no in range(NUM_CLIENTS):
    np.random.shuffle(client_data[client_no])

# Per-client summary
separator = "=" * 60
print("\n Distribution par client (Non-IID):")
print(separator)
for client_no in range(NUM_CLIENTS):
    client_rows = client_data[client_no]
    sample_count = len(client_rows)
    attack_count = y_train[client_rows].sum()
    attack_share = attack_count / sample_count
    print(f"Client {client_no+1}: {sample_count:6d} samples | Attacks: {attack_count:5d} ({attack_share:.2%})")

print(separator)
print(f" Total samples distributed: {sum(len(client_data[i]) for i in range(NUM_CLIENTS))}")

In [None]:
class IDSEnvironment:
    """
    RL-style environment for intrusion detection.

    Every state is the feature vector of one network connection. An episode
    walks once through all samples in a freshly shuffled order.
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.n_samples = len(X)
        self.current_idx = 0
        # Visiting order of the samples; reshuffled on every reset().
        self.indices = np.arange(self.n_samples)
        np.random.shuffle(self.indices)

    def reset(self):
        """Start a new episode: reshuffle the visiting order and return the first state."""
        self.current_idx = 0
        np.random.shuffle(self.indices)
        first_row = self.indices[0]
        return self.X[first_row]

    def step(self, action):
        """
        Score ``action`` against the current sample and move to the next one.

        Args:
            action: 0 (normal) or 1 (attack)

        Returns:
            next_state, reward, done, info
            - reward is +1.0 when ``action`` equals the true label, -1.0 otherwise
            - next_state is an all-zero vector once the episode is over
            - info holds the true label of the sample that was just scored
        """
        scored_row = self.indices[self.current_idx]
        true_label = self.y[scored_row]
        reward = 1.0 if action == true_label else -1.0

        # Advance the cursor; the episode ends after the last sample.
        self.current_idx += 1
        done = self.current_idx >= self.n_samples

        next_state = (
            np.zeros_like(self.X[0])
            if done
            else self.X[self.indices[self.current_idx]]
        )

        return next_state, reward, done, {'true_label': true_label}

print(" IDSEnvironment class d√©finie")

In [None]:
class DQN(nn.Module):
    """Deep Q-Network for intrusion detection: a three-layer fully connected net.

    The output has one Q-value per action (``output_dim``, 2 by default).
    """

    def __init__(self, input_dim, hidden1=128, hidden2=64, output_dim=2):
        super().__init__()
        # Layer names are part of the state-dict keys exchanged between clients.
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, output_dim)

    def forward(self, x):
        """Return the raw Q-values (no activation on the last layer)."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

print(" DQN class d√©finie")

In [None]:
class ReplayBuffer:
    """Bounded store of transitions (s, a, r, s', done)."""

    def __init__(self, capacity=10000):
        # A deque with maxlen drops its oldest entry when a new one is appended at capacity.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            Tuple of five arrays: states, actions, rewards, next_states, dones.
        """
        picked = np.random.choice(len(self.buffer), batch_size, replace=False)
        transitions = [self.buffer[position] for position in picked]

        def stack_field(field_no):
            return np.array([transition[field_no] for transition in transitions])

        return tuple(stack_field(field_no) for field_no in range(5))

    def __len__(self):
        return len(self.buffer)

print(" ReplayBuffer class d√©finie")

In [None]:
def train_local_dqn(client_idx, client_indices, global_weights, input_dim, 
                     episodes=5, epsilon=0.1, gamma=0.99, lr=0.001, 
                     batch_size=64, buffer_capacity=10000,
                     target_update_steps=None):
    """
    Train a local DQN agent on one client's share of the training data.

    Reads the module-level ``X_train`` / ``y_train`` arrays and uses the
    ``DQN``, ``ReplayBuffer`` and ``IDSEnvironment`` classes of this notebook.

    Args:
        client_idx: Client identifier. Not used by the training itself.
        client_indices: Row positions (into X_train / y_train) owned by the client.
        global_weights: State dict loaded into both the policy and the target
            network, or None to start from a random initialisation.
        input_dim: Number of input features.
        episodes: Number of passes over the client's samples.
        epsilon: Probability of picking a random action (epsilon-greedy).
        gamma: Discount factor.
        lr: Adam learning rate.
        batch_size: Replay mini-batch size; learning starts once the buffer
            holds at least this many transitions.
        buffer_capacity: Maximum number of transitions kept in the replay buffer.
        target_update_steps: If set, the target network is re-synchronised with
            the policy network every ``target_update_steps`` optimiser steps.
            With the default (None) the target network keeps its initial
            weights for the whole call.

    Returns:
        model_weights: State dict of the policy network after training.
        metrics: Dict with 'loss' (mean loss per optimiser step), 'reward'
            (mean total reward per episode), 'accuracy' (share of steps with a
            positive reward, exploration steps included) and 'n_samples'.
    """
    # Model initialisation
    policy_net = DQN(input_dim)
    target_net = DQN(input_dim)
    
    if global_weights is not None:
        policy_net.load_state_dict(global_weights)
        target_net.load_state_dict(global_weights)
    else:
        target_net.load_state_dict(policy_net.state_dict())
    
    # A client without data has nothing to learn from: hand back the untouched
    # weights with neutral metrics instead of failing inside the environment.
    if len(client_indices) == 0:
        return policy_net.state_dict(), {
            'loss': 0.0,
            'reward': 0.0,
            'accuracy': 0,
            'n_samples': 0
        }
    
    # Client data
    X_client = X_train[client_indices]
    y_client = y_train[client_indices]
    
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    # Replay buffer (shared by all episodes of this call)
    replay_buffer = ReplayBuffer(buffer_capacity)
    
    # Environment
    env = IDSEnvironment(X_client, y_client)
    
    total_loss = 0
    total_reward = 0
    total_steps = 0
    n_updates = 0  # number of optimiser steps actually performed
    correct_predictions = 0
    
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        
        while not done:
            # epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = np.random.randint(0, 2)
            else:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0)
                    q_values = policy_net(state_tensor)
                    action = q_values.argmax().item()
            
            # Step
            next_state, reward, done, info = env.step(action)
            
            # Store the transition
            replay_buffer.push(state, action, reward, next_state, done)
            
            episode_reward += reward
            if reward > 0:
                correct_predictions += 1
            
            state = next_state
            total_steps += 1
            
            # Training step
            if len(replay_buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
                
                states_t = torch.FloatTensor(states)
                actions_t = torch.LongTensor(actions)
                rewards_t = torch.FloatTensor(rewards)
                next_states_t = torch.FloatTensor(next_states)
                dones_t = torch.FloatTensor(dones)
                
                # Q(s,a) of the actions that were actually taken
                current_q = policy_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
                
                # Bootstrapped target r + gamma * max_a' Q_target(s',a'); the
                # (1 - done) factor removes the bootstrap term on terminal transitions.
                with torch.no_grad():
                    next_q = target_net(next_states_t).max(1)[0]
                    target_q = rewards_t + gamma * next_q * (1 - dones_t)
                
                # Loss
                loss = criterion(current_q, target_q)
                
                # Optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                n_updates += 1
                
                # Optional periodic hard update of the target network
                if target_update_steps and n_updates % target_update_steps == 0:
                    target_net.load_state_dict(policy_net.state_dict())
        
        total_reward += episode_reward
    
    # Metrics. The loss is averaged over the optimiser steps that really ran,
    # not over an estimate derived from the number of environment steps.
    avg_loss = total_loss / max(n_updates, 1)
    avg_reward = total_reward / episodes if episodes > 0 else 0.0
    accuracy = correct_predictions / total_steps if total_steps > 0 else 0
    
    metrics = {
        'loss': avg_loss,
        'reward': avg_reward,
        'accuracy': accuracy,
        'n_samples': len(client_indices)
    }
    
    return policy_net.state_dict(), metrics

print(" train_local_dqn function d√©finie")

In [None]:
def fedavg_aggregate(client_weights, client_metrics):
    """
    FedAvg aggregation: average of the client state dicts, weighted by sample count.

    Args:
        client_weights: List of state dicts (one per client) sharing the same keys.
        client_metrics: List of dicts, aligned with ``client_weights``, each
            providing an 'n_samples' entry.

    Returns:
        A new state dict whose tensors are float32 weighted averages. The input
        state dicts are left untouched.

    Raises:
        ValueError: If no client is given, if the two lists differ in length,
            or if the total number of samples is not positive.
    """
    if len(client_weights) == 0:
        raise ValueError("fedavg_aggregate needs at least one client state dict")
    if len(client_weights) != len(client_metrics):
        raise ValueError(
            f"got {len(client_weights)} state dicts but {len(client_metrics)} metric dicts"
        )
    
    total_samples = sum(m['n_samples'] for m in client_metrics)
    if total_samples <= 0:
        raise ValueError("total number of client samples must be positive")
    
    # One coefficient per client, computed once rather than for every tensor.
    coefficients = [m['n_samples'] / total_samples for m in client_metrics]
    
    # Copy the first client's dict to inherit its key set and container type.
    global_weights = copy.deepcopy(client_weights[0])
    
    for key in global_weights.keys():
        aggregated = torch.zeros_like(global_weights[key], dtype=torch.float32)
        
        for coefficient, w in zip(coefficients, client_weights):
            aggregated += w[key].float() * coefficient
        
        global_weights[key] = aggregated
    
    return global_weights

print(" fedavg_aggregate function d√©finie")

In [None]:
def attention_aggregate(client_weights, client_metrics):
    """
    Aggregate client state dicts with accuracy-aware ("attention") coefficients.

    attention_multiplier_i = 1 + (1 - accuracy_i)
    attention_i = n_samples_i * attention_multiplier_i

    The raw scores are normalised to sum to 1, so a client with a lower local
    accuracy gets a larger share than its sample count alone would give it.

    Returns:
        (global_weights, attention_weights): the aggregated float32 state dict
        and the list of normalised coefficients (one per client).
    """
    # Raw, un-normalised score of every client.
    raw_scores = [
        m['n_samples'] * (1 + (1 - m['accuracy']))
        for m in client_metrics
    ]
    score_sum = sum(raw_scores)
    
    # Normalise
    attention_weights = [score / score_sum for score in raw_scores]
    
    # Weighted sum, tensor by tensor, starting from a copy of the first client's dict.
    global_weights = copy.deepcopy(client_weights[0])
    
    for name in global_weights.keys():
        blended = torch.zeros_like(global_weights[name], dtype=torch.float32)
        for position, state in enumerate(client_weights):
            blended += state[name].float() * attention_weights[position]
        global_weights[name] = blended
    
    print("\n Coefficients d'attention:")
    for client_no, (att, m) in enumerate(zip(attention_weights, client_metrics), start=1):
        print(f"  Client {client_no}: {att:.4f} (acc={m['accuracy']:.2%}, samples={m['n_samples']})")
    
    return global_weights, attention_weights

print("attention_aggregate function d√©finie")

In [None]:
# Federated-training hyperparameters
ROUNDS = 10               # communication rounds
EPISODES_PER_ROUND = 3    # local episodes each client runs per round
INPUT_DIM = X_train.shape[1]

print(" Configuration FL-DQN:")
for _setting_name, _setting_value in (
    ("Rounds", ROUNDS),
    ("Episodes par round", EPISODES_PER_ROUND),
    ("Input dimension", INPUT_DIM),
    ("Nombre de clients", NUM_CLIENTS),
):
    print(f"  - {_setting_name}: {_setting_value}")

In [None]:
# ========== RUN 1: FedAvg ==========
print(" D√âBUT ENTRA√éNEMENT FL-DQN FEDAVG")

# None on the first round: every client then starts from its own random initialisation.
global_weights_fedavg = None
history_fedavg = {'accuracy': [], 'loss': [], 'reward': []}

for round_num in range(ROUNDS):
    print(f"\n--- Round {round_num + 1}/{ROUNDS} ---")
    
    client_weights = []
    client_metrics = []
    
    # Local training of every client, starting from the current global weights
    for client_idx in range(NUM_CLIENTS):
        weights, metrics = train_local_dqn(
            client_idx=client_idx,
            client_indices=client_data[client_idx],
            global_weights=global_weights_fedavg,
            input_dim=INPUT_DIM,
            episodes=EPISODES_PER_ROUND,
            epsilon=0.1,
            lr=0.001
        )
        client_weights.append(weights)
        client_metrics.append(metrics)
    
    # FedAvg aggregation
    global_weights_fedavg = fedavg_aggregate(client_weights, client_metrics)
    
    # Round-level metrics: unweighted mean over the clients.
    # avg_loss must be computed here; it was previously read without being
    # defined in this cell, which fails on a fresh kernel.
    avg_accuracy = np.mean([m['accuracy'] for m in client_metrics])
    avg_loss = np.mean([m['loss'] for m in client_metrics])
    avg_reward = np.mean([m['reward'] for m in client_metrics])
    
    history_fedavg['accuracy'].append(avg_accuracy)
    history_fedavg['loss'].append(avg_loss)
    history_fedavg['reward'].append(avg_reward)
    
    print(f"  Avg Accuracy: {avg_accuracy:.4f} | Loss: {avg_loss:.4f} | Reward: {avg_reward:.2f}")

print("\n Entra√Ænement FedAvg termin√©")

In [None]:
# ========== RUN 2: Attention ==========
print("üöÄ D√âBUT ENTRA√éNEMENT FL-DQN ATTENTION")

# None on the first round: every client then starts from its own random initialisation.
global_weights_attention = None
history_attention = {'accuracy': [], 'loss': [], 'reward': [], 'attention_weights': []}

for round_no in range(1, ROUNDS + 1):
    print(f"\n--- Round {round_no}/{ROUNDS} ---")
    
    # Local training of every client (in client order), starting from the current global weights
    round_outputs = [
        train_local_dqn(
            client_idx=client_idx,
            client_indices=client_data[client_idx],
            global_weights=global_weights_attention,
            input_dim=INPUT_DIM,
            episodes=EPISODES_PER_ROUND,
            epsilon=0.1,
            lr=0.001
        )
        for client_idx in range(NUM_CLIENTS)
    ]
    client_weights = [output[0] for output in round_outputs]
    client_metrics = [output[1] for output in round_outputs]
    
    # Attention-based aggregation
    global_weights_attention, attention_w = attention_aggregate(client_weights, client_metrics)
    
    # Round-level metrics: unweighted mean over the clients
    avg_accuracy, avg_loss, avg_reward = (
        np.mean([m[metric_key] for m in client_metrics])
        for metric_key in ('accuracy', 'loss', 'reward')
    )
    
    for history_key, round_value in (
        ('accuracy', avg_accuracy),
        ('loss', avg_loss),
        ('reward', avg_reward),
        ('attention_weights', attention_w),
    ):
        history_attention[history_key].append(round_value)
    
    print(f"  Avg Accuracy: {avg_accuracy:.4f} | Loss: {avg_loss:.4f} | Reward: {avg_reward:.2f}")

print("\n Entra√Ænement Attention termin√©")

In [None]:
# Training curves: one panel per metric, FedAvg versus Attention.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (history key, y-axis label, panel title)
_panels = [
    ('accuracy', 'Accuracy', 'Training Accuracy'),
    ('loss', 'Loss', 'Training Loss'),
    ('reward', 'Avg Reward', 'Training Reward'),
]

for _panel_ax, (_metric_key, _y_label, _panel_title) in zip(axes, _panels):
    _panel_ax.plot(history_fedavg[_metric_key], label='FedAvg', marker='o')
    _panel_ax.plot(history_attention[_metric_key], label='Attention', marker='s')
    _panel_ax.set_xlabel('Round')
    _panel_ax.set_ylabel(_y_label)
    _panel_ax.set_title(_panel_title)
    _panel_ax.legend()
    _panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
def evaluate_model(model_weights, X, y, model_name="Model"):
    """
    Evaluate a DQN, given as a state dict, on a labelled dataset.

    The predicted class is the action with the highest Q-value; the score used
    for ROC-AUC is the softmax of the two Q-values, taken for class 1.

    Args:
        model_weights: State dict compatible with ``DQN(X.shape[1])``.
        X: 2-D feature array.
        y: Binary ground-truth labels (0 = normal, 1 = attack).
        model_name: Label stored in the result under 'model'.

    Returns:
        Dict with 'model', 'accuracy', 'precision', 'recall', 'f1', 'auc',
        'fpr', 'y_pred', 'y_pred_proba' and 'cm' (always a 2x2 matrix ordered
        as labels [0, 1]). 'auc' is NaN when ``y`` holds a single class,
        because ROC-AUC is undefined in that case.
    """
    model = DQN(X.shape[1])
    model.load_state_dict(model_weights)
    model.eval()
    
    # Predictions
    with torch.no_grad():
        X_tensor = torch.FloatTensor(X)
        outputs = model(X_tensor)
        y_pred = outputs.argmax(dim=1).numpy()
        y_pred_proba = torch.softmax(outputs, dim=1)[:, 1].numpy()
    
    # Metrics
    acc = accuracy_score(y, y_pred)
    prec = precision_score(y, y_pred, zero_division=0)
    rec = recall_score(y, y_pred, zero_division=0)
    f1 = f1_score(y, y_pred, zero_division=0)
    # roc_auc_score cannot be computed when only one class is present in y.
    if len(np.unique(y)) > 1:
        auc_score = roc_auc_score(y, y_pred_proba)
    else:
        auc_score = float('nan')
    
    # False-positive rate. Fixing the labels keeps the matrix 2x2 even when a
    # class is absent, so the four-way unpacking below cannot fail.
    cm = confusion_matrix(y, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
    
    results = {
        'model': model_name,
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1,
        'auc': auc_score,
        'fpr': fpr,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba,
        'cm': cm
    }
    
    return results

print(" evaluate_model function d√©finie")

In [None]:
# 1. Centralised baseline (scikit-learn MLP trained on the whole training set)
print("\n Training centralized baseline (MLP)")
mlp_baseline = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=50, random_state=SEED)
mlp_baseline.fit(X_train, y_train)

# Hard predictions and the second predict_proba column, used as the score for ROC-AUC.
y_pred_mlp = mlp_baseline.predict(X_test)
y_pred_proba_mlp = mlp_baseline.predict_proba(X_test)[:, 1]

acc_mlp = accuracy_score(y_test, y_pred_mlp)
prec_mlp, rec_mlp, f1_mlp = (
    score_fn(y_test, y_pred_mlp, zero_division=0)
    for score_fn in (precision_score, recall_score, f1_score)
)
auc_mlp = roc_auc_score(y_test, y_pred_proba_mlp)

# False-positive rate from the confusion matrix.
cm_mlp = confusion_matrix(y_test, y_pred_mlp)
tn, fp, fn, tp = cm_mlp.ravel()
fpr_mlp = fp / (fp + tn)

# Same layout as the dict returned by evaluate_model.
results_mlp = dict(
    model='Centralized MLP',
    accuracy=acc_mlp,
    precision=prec_mlp,
    recall=rec_mlp,
    f1=f1_mlp,
    auc=auc_mlp,
    fpr=fpr_mlp,
    y_pred=y_pred_mlp,
    y_pred_proba=y_pred_proba_mlp,
    cm=cm_mlp,
)

print(f"  Accuracy: {acc_mlp:.4f} | F1: {f1_mlp:.4f} | AUC: {auc_mlp:.4f}")

In [None]:
# 2. FL-DQN FedAvg: score the final FedAvg global model on the test split.
print("\n  Evaluating FL-DQN FedAvg...")
results_fedavg = evaluate_model(
    global_weights_fedavg,
    X_test,
    y_test,
    model_name="FL-DQN FedAvg",
)
print("  Accuracy: {accuracy:.4f} | F1: {f1:.4f} | AUC: {auc:.4f}".format(**results_fedavg))

In [None]:
# 3. FL-DQN Attention: score the final attention-aggregated global model on the test split.
print("\n Evaluating FL-DQN Attention...")
results_attention = evaluate_model(
    global_weights_attention,
    X_test,
    y_test,
    model_name="FL-DQN Attention",
)
print("  Accuracy: {accuracy:.4f} | F1: {f1:.4f} | AUC: {auc:.4f}".format(**results_attention))

In [None]:
# Confusion matrices of the three models, side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

_models_to_show = [results_mlp, results_fedavg, results_attention]
for _panel_no, _model_result in enumerate(_models_to_show):
    _panel_ax = axes[_panel_no]
    sns.heatmap(_model_result['cm'], annot=True, fmt='d', cmap='Blues', ax=_panel_ax)
    _panel_ax.set_title(f"{_model_result['model']}\nAcc={_model_result['accuracy']:.3f}")
    _panel_ax.set_xlabel('Predicted')
    _panel_ax.set_ylabel('True')

plt.tight_layout()
plt.show()

In [None]:
# ROC curves of the three models on the test split.
plt.figure(figsize=(8, 6))

for _model_result in (results_mlp, results_fedavg, results_attention):
    _roc_points = roc_curve(y_test, _model_result['y_pred_proba'])
    _curve_label = f"{_model_result['model']} (AUC={_model_result['auc']:.3f})"
    # roc_curve returns (fpr, tpr, thresholds); only the first two are plotted.
    plt.plot(_roc_points[0], _roc_points[1], label=_curve_label)

# Diagonal = performance of a random classifier.
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Comparison table: one row per model, metrics formatted with four decimals.
# (column header, key in the results dict)
_metric_columns = [
    ('Accuracy', 'accuracy'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1-Score', 'f1'),
    ('ROC-AUC', 'auc'),
    ('FPR', 'fpr'),
]

comparison_df = pd.DataFrame([
    {
        'M√©thode': model_result['model'],
        **{header: f"{model_result[key]:.4f}" for header, key in _metric_columns},
    }
    for model_result in (results_mlp, results_fedavg, results_attention)
])

_rule = "=" * 80
print("\n" + _rule)
print(" TABLEAU COMPARATIF DES PERFORMANCES")
print(_rule)
print(comparison_df.to_string(index=False))
print(_rule)

### Analyse des résultats

**Pourquoi l'attention dynamique aide en contexte Non-IID ?**

1. **Gestion de l'hétérogénéité** : L'attention dynamique pondère les contributions des clients en fonction de leur performance locale.
   
2. **Compensation des biais** : Les clients avec des données plus difficiles ou déséquilibrées (faible accuracy) reçoivent un poids plus élevé, permettant au modèle global de mieux généraliser.

3. **Adaptation aux distributions locales** : Contrairement à FedAvg qui pondère uniquement par le nombre d'échantillons, l'attention considère aussi la "difficulté" des données, ce qui est crucial en Non-IID.

4. **Convergence robuste** : L'agrégation par attention permet une convergence plus stable face aux variations de distributions entre clients.

---

## Limites et perspectives

### Limites de cette implémentation

1. **Simulation RL** : Le RL est simulé sur un dataset tabulaire statique. Un vrai environnement RL nécessiterait des interactions temps réel avec le réseau.

2. **Communication** : Nous n'avons pas modélisé les coûts de communication ni les délais réseau entre clients et serveur.

3. **Attaques FL** : Aucune protection contre les attaques de type poisoning, backdoor ou model inversion.

4. **Classification binaire uniquement** : Pas de détection multi-classe des types d'attaques (DoS, Probe, R2L, U2R).

5. **Scalabilité** : L'implémentation est limitée à 5 clients ; un vrai déploiement nécessiterait des centaines de clients.

---

### Perspectives d'amélioration

1. **Multi-classe** : Étendre à la classification multi-classe pour identifier le type d'attaque.

2. **Vrai réseau distribué** : Déployer sur plusieurs machines avec communication réelle (ex: gRPC, WebSockets).

3. **Robust Aggregation** : Implémenter des mécanismes de défense (ex: Krum, Trimmed Mean, Byzantine-robust FL).

4. **Prioritized Experience Replay** : Améliorer le DQN avec PER pour un apprentissage plus efficace.

5. **Privacy-preserving FL** : Ajouter du differential privacy et du secure aggregation.

6. **Datasets réels** : Tester sur des datasets plus récents (CICIDS2017, CSE-CIC-IDS2018, UNSW-NB15).

7. **Hyperparameter tuning** : Optimiser epsilon, learning rate, buffer size, etc.
