# Federated IDS (10 clients) with TenSEAL (CKKS) encrypted weight aggregation



In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt

# TenSEAL is treated as optional here: when the import fails for any reason,
# `ts` stays None so that code checking `ts is None` can report the missing
# dependency instead of failing on an undefined name.
ts = None
try:
    import tenseal as ts
except Exception:
    print('tenseal not available. Install tenseal (pip install tenseal) to use encrypted aggregation.')


tenseal not available. Install tenseal (pip install tenseal) to use encrypted aggregation.


In [None]:
import os
import glob
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler


# Directory scanned for the dataset CSV files.
DATA_DIR = '/content'

# Sort the paths: glob returns files in an OS-dependent order, and the class ids
# assigned below follow the order in which labels are first encountered.
csv_files = sorted(glob.glob(os.path.join(DATA_DIR, '*.csv')))
if not csv_files:
    raise FileNotFoundError(f'No CSV files found in {DATA_DIR}. Please upload them and re-run.')

print(f'Found {len(csv_files)} CSV files. Reading and concatenating...')
dfs = [pd.read_csv(file, low_memory=False) for file in csv_files]
df = pd.concat(dfs, ignore_index=True)
print('Raw combined shape:', df.shape)

# Locate the label column: 'Label', then ' Label' (leading space), otherwise
# fall back to the last column.
if 'Label' in df.columns:
    label_col = 'Label'
elif ' Label' in df.columns:
    label_col = ' Label'
else:
    label_col = df.columns[-1]
    print(' Using last column as label:', label_col)

# Keep only numeric feature columns (the label column is always kept).
non_numeric = df.select_dtypes(exclude=[np.number]).columns.tolist()
non_numeric = [c for c in non_numeric if c != label_col]
df = df.drop(columns=non_numeric, errors='ignore')

# Drop unlabeled rows BEFORE building the mapping. Otherwise a NaN label would be
# stringified to "NAN" by the loop below and consume an attack class index.
df = df.dropna(subset=[label_col])

# Encode labels: anything starting with "BENIGN" -> 0, every distinct attack name
# -> 1, 2, ... in order of first appearance. The comparison uses the stripped,
# upper-cased name so that whitespace/case variants of one attack share an id,
# while `label_mapping` stays keyed by the raw label values.
unique_labels = df[label_col].unique()
label_mapping = {}
attack_index_by_name = {}
for lbl in unique_labels:
    lbl_str = str(lbl).strip().upper()
    if lbl_str.startswith("BENIGN"):
        label_mapping[lbl] = 0
    else:
        if lbl_str not in attack_index_by_name:
            attack_index_by_name[lbl_str] = len(attack_index_by_name) + 1
        label_mapping[lbl] = attack_index_by_name[lbl_str]
current_index = len(attack_index_by_name)

df[label_col] = df[label_col].map(label_mapping)

print("Fine-grained label distribution:")
print(df[label_col].value_counts())

print('After cleaning shape:', df.shape)

# Explicit float copy so that inf/NaN can be written into the matrix in place.
X = df.drop(columns=[label_col]).values.astype(np.float64)
y = df[label_col].values

# Treat +/-inf as missing, then impute every missing value with its column mean.
X[np.isinf(X)] = np.nan
col_means = np.nanmean(X, axis=0)
# A column without any finite value has a NaN mean; impute such columns with 0
# so that no NaN survives into the scaled feature matrix.
col_means = np.where(np.isnan(col_means), 0.0, col_means)
inds = np.where(np.isnan(X))
X[inds] = np.take(col_means, inds[1])

# NOTE(review): the scaler is fit on every row, including rows that a later cell
# holds out as the global test set. Fitting on the training split only would
# avoid leaking test statistics into training.
scaler = StandardScaler()
X = scaler.fit_transform(X)
num_classes = len(np.unique(y))
print(' Features shape:', X.shape)
print(' Labels shape:', y.shape)


Found 5 CSV files. Reading and concatenating...
Raw combined shape: (628218, 79)
Fine-grained label distribution:
 Label
0    508840
1     69636
4     48179
5      1537
3        21
2         5
Name: count, dtype: int64
After cleaning shape: (628218, 79)
✅ Features shape: (628218, 78)
✅ Labels shape: (628218,)


In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from collections import Counter
import numpy as np
import random
# Hold out a stratified 20% of all samples as the global test set; the fixed
# random_state makes this split repeatable.
X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42
)

# The global test set is served in a fixed (non-shuffled) order, 256 rows per batch.
global_test_dataset = TensorDataset(
    torch.tensor(X_test_full, dtype=torch.float32),
    torch.tensor(y_test_full, dtype=torch.long),
)
global_test_loader = DataLoader(global_test_dataset, shuffle=False, batch_size=256)

for description, count in (
    ("Total samples", len(X)),
    ("Global training samples (80%)", len(X_train_full)),
    ("Global test samples (20%)", len(X_test_full)),
):
    print(f"{description}: {count}")


# --- Client partitioning configuration ---
num_clients = 10
num_priority_clients = 4      # clients guaranteed a few samples of every splittable class
min_samples_for_split = 2     # per-class samples handed to each priority client
alpha = 0.1 # Dirichlet parameter
partition_seed = 42           # same value as the random_state used for the data splits

# Seed both RNGs used by the partitioning code (`random.sample` below and the
# `np.random` shuffle/dirichlet/choice calls that follow) so that the client
# partition is identical on every Restart & Run All.
random.seed(partition_seed)
np.random.seed(partition_seed)

unique_classes = np.unique(y_train_full)
num_classes = len(unique_classes)
# Row positions (into the global training arrays) of every class.
class_indices = {cls: np.where(y_train_full == cls)[0] for cls in unique_classes}

# A class is "splittable" when it has enough training samples to give
# `min_samples_for_split` of them to each priority client; all others are "rare".
min_global_samples_needed = num_priority_clients * min_samples_for_split
splittable_classes = {cls for cls, indices in class_indices.items() if len(indices) >= min_global_samples_needed}
rare_classes = {cls for cls in unique_classes if cls not in splittable_classes}

print(f"\nClasses identified as 'splittable' (>= {min_global_samples_needed} samples): {splittable_classes}")
print(f"Classes identified as 'rare' (< {min_global_samples_needed} samples): {rare_classes}")

# Per-client list of global training row positions, plus a mask of rows already handed out.
client_indices = {i: [] for i in range(num_clients)}
assigned_indices_mask = np.zeros(len(X_train_full), dtype=bool)

priority_client_ids = random.sample(range(num_clients), num_priority_clients)
print(f"Prioritizing clients for local testing: {priority_client_ids}")

# Step 1: hand every priority client `min_samples_for_split` samples of each
# splittable class, taken from consecutive windows of that class's shuffled rows.
for cls in splittable_classes:
    shuffled_rows = class_indices[cls]
    np.random.shuffle(shuffled_rows)  # in place: also reorders class_indices[cls]

    for slot, client_id in enumerate(priority_client_ids):
        window_start = slot * min_samples_for_split
        picked_rows = shuffled_rows[window_start:window_start + min_samples_for_split]
        client_indices[client_id].extend(picked_rows)
        assigned_indices_mask[picked_rows] = True  # these rows are no longer available

# Step 2: spread every row not handed out above over all clients, class by class,
# using Dirichlet(alpha) proportions.
remaining_global_indices = np.where(~assigned_indices_mask)[0]
if len(remaining_global_indices) == 0:
     print("Warning: All data was assigned during prioritization. No remaining data for Dirichlet distribution.")
else:
    X_remaining = X_train_full[remaining_global_indices]
    y_remaining = y_train_full[remaining_global_indices]

    # Positions are relative to `remaining_global_indices`, not to the full training set.
    remaining_class_indices = {cls: np.where(y_remaining == cls)[0] for cls in unique_classes}

    for cls in unique_classes:
        local_positions = remaining_class_indices[cls]
        total_for_class = len(local_positions)
        if total_for_class == 0:
            continue

        np.random.shuffle(local_positions)
        proportions = np.random.dirichlet(np.ones(num_clients) * alpha)

        # Truncate to whole samples, then give the leftover samples to randomly
        # chosen clients until every sample of the class is accounted for.
        counts = (proportions * total_for_class).astype(int)
        while counts.sum() < total_for_class:
            counts[np.random.choice(num_clients)] += 1

        cursor = 0
        for client_id, take in enumerate(counts):
            if take <= 0:
                continue
            window = local_positions[cursor:cursor + take]
            cursor += take
            # Translate positions within the remaining pool back to global training rows.
            client_indices[client_id].extend(remaining_global_indices[window])

# Build one (X_train, y_train, X_test, y_test) tuple per client that received data.
clients = []
clients_with_local_test = 0
local_test_fraction = 0.2

print("\n--- Creating Local Train/Test Splits for each client ---")
for i in range(num_clients):

    indices = np.unique(np.array(client_indices[i], dtype=int))

    if len(indices) == 0:
        print(f"Client {i+1}: Received no data. Skipping.")
        continue

    Xc, yc = X_train_full[indices], y_train_full[indices]

    label_counts = Counter(yc)
    min_class_count = min(label_counts.values()) if label_counts else 0

    # A stratified split needs both parts to hold at least one sample per class,
    # otherwise train_test_split raises ValueError. The sizes are computed the
    # same way scikit-learn does: test = ceil(fraction * n), train = the rest.
    n_local_test = int(np.ceil(local_test_fraction * len(yc)))
    n_local_train = len(yc) - n_local_test
    n_local_classes = len(label_counts)

    if min_class_count < min_samples_for_split:
        # Placeholder one-row "test set" (taken from the training rows) so that
        # every client tuple has the same structure.
        Xc_train, yc_train = Xc, yc
        Xc_test, yc_test = Xc[0:1], yc[0:1]
        print(f"Client {i+1}: Total {len(yc)} samples (Too few samples in class {min(label_counts, key=label_counts.get)} ({min_class_count}) to stratify, using all for training)")
    elif min(n_local_train, n_local_test) < n_local_classes:
        # Every class has enough samples, but the client is too small overall for
        # an 80/20 split that keeps all of its classes on both sides.
        Xc_train, yc_train = Xc, yc
        Xc_test, yc_test = Xc[0:1], yc[0:1]
        print(f"Client {i+1}: Total {len(yc)} samples (Too few samples to hold out all {n_local_classes} classes, using all for training)")
    else:

        Xc_train, Xc_test, yc_train, yc_test = train_test_split(
            Xc, yc, test_size=local_test_fraction, random_state=42, stratify=yc
        )
        print(f"Client {i+1}: Total {len(yc)} samples -> Train: {len(Xc_train)}, Test: {len(Xc_test)} (Local test created)")
        clients_with_local_test += 1

    clients.append((Xc_train, yc_train, Xc_test, yc_test))

print(f"\nSuccessfully created data splits for {len(clients)} clients.")
print(f"Number of clients with a meaningful local test set: {clients_with_local_test}")

if len(clients) != num_clients:
     print(f"Warning: Expected {num_clients} clients, but only created {len(clients)}.")
if clients_with_local_test < num_priority_clients:
     print(f"Warning: Targeted {num_priority_clients} clients for local testing, but only achieved {clients_with_local_test}. "
           "This can happen with highly skewed data after the Dirichlet step.")

Total samples: 628218
Global training samples (80%): 502574
Global test samples (20%): 125644

Classes identified as 'splittable' (>= 8 samples): {np.int64(0), np.int64(1), np.int64(3), np.int64(4), np.int64(5)}
Classes identified as 'rare' (< 8 samples): {np.int64(2)}
Prioritizing clients for local testing: [8, 2, 4, 7]

--- Creating Local Train/Test Splits for each client ---
Client 1: Total 151644 samples (Too few samples in class 5 (1) to stratify, using all for training)
Client 2: Total 23907 samples (Too few samples in class 3 (1) to stratify, using all for training)
Client 3: Total 23133 samples -> Train: 18506, Test: 4627 (Local test created)
Client 4: Total 214 samples -> Train: 171, Test: 43 (Local test created)
Client 5: Total 245929 samples -> Train: 196743, Test: 49186 (Local test created)
Client 6: Total 152 samples -> Train: 121, Test: 31 (Local test created)
Client 7: Total 1566 samples -> Train: 1252, Test: 314 (Local test created)
Client 8: Total 42614 samples -> Trai

In [None]:
def model_to_vector(state_dict):
    """Flatten all tensors of a state dict into a single float64 vector.

    Args:
        state_dict: mapping of name -> torch tensor, iterated in its own order.

    Returns:
        (flat, shapes): `flat` is a 1-D float64 NumPy array holding every tensor's
        values back to back; `shapes` maps each name to the tensor's original shape.
    """
    shapes = {name: tensor.shape for name, tensor in state_dict.items()}
    pieces = [tensor.cpu().numpy().ravel() for tensor in state_dict.values()]
    flat = np.concatenate(pieces).astype(np.float64)
    return flat, shapes

def vector_to_model(state_dict_template, flat_vec, shapes):
    """Rebuild a state dict from a flat vector.

    Args:
        state_dict_template: mapping of name -> tensor; supplies the key order
            and the dtype of every rebuilt tensor.
        flat_vec: 1-D NumPy array holding all values back to back, in the
            template's key order.
        shapes: mapping of name -> shape used to size and reshape each slice.

    Returns:
        A new dict of name -> torch tensor.
    """
    rebuilt = {}
    offset = 0
    for name, template_tensor in state_dict_template.items():
        shape = shapes[name]
        count = int(np.prod(shape))
        segment = flat_vec[offset:offset + count]
        offset += count
        rebuilt[name] = torch.tensor(segment.reshape(shape), dtype=template_tensor.dtype)
    return rebuilt

def create_tenseal_context():
    """Create a TenSEAL CKKS context with Galois keys and a 2**40 global scale.

    Returns:
        The configured context (polynomial modulus degree 8192,
        coefficient modulus bit sizes [60, 40, 40, 60]).

    Raises:
        RuntimeError: if `ts` is None, i.e. TenSEAL could not be imported.
    """
    if ts is None:
        raise RuntimeError('tenseal not installed')
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60],
    )
    context.generate_galois_keys()
    context.global_scale = 2**40
    return context

def encrypt_vector(ctx, vec):
    """Encrypt `vec` as a single CKKS vector under the TenSEAL context `ctx`."""
    ciphertext = ts.ckks_vector(ctx, vec)
    return ciphertext

def decrypt_vector(enc_vec):
    """Decrypt `enc_vec` (any object exposing `.decrypt()`) into a NumPy array."""
    plain_values = enc_vec.decrypt()
    return np.array(plain_values)


In [None]:
def train_local(model, X_train, y_train, epochs=3, batch_size=64, lr=1e-3, device='cpu'):
    """Train `model` on one client's data with Adam and cross-entropy loss.

    Args:
        model: the torch module to train (moved to `device` and put in train mode).
        X_train: array-like of features, converted to a float32 tensor.
        y_train: array-like of class ids, converted to a long tensor.
        epochs: number of passes over the data.
        batch_size: mini-batch size; batches are reshuffled every epoch.
        lr: Adam learning rate.
        device: device on which training runs.

    Returns:
        The trained model's state dict.
    """
    model = model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    features = torch.tensor(X_train, dtype=torch.float32)
    targets = torch.tensor(y_train, dtype=torch.long)
    batches = DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for batch_x, batch_y in batches:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(model(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()

    return model.state_dict()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import tenseal as ts
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from collections import Counter # Added for class counts


class IDS1DCNN(nn.Module):
    """1D convolutional classifier over a flat feature vector.

    The input of shape (batch, input_dim) is treated as a single-channel
    sequence, passed through two Conv1d/ReLU/Dropout stages (128 then 64
    channels, length preserved by padding), flattened, and classified by a
    two-layer fully connected head.
    """

    def __init__(self, input_dim=78, num_classes=6):
        super().__init__()

        # Convolutional stage: kernel 3 with padding 1 keeps the sequence length
        # equal to input_dim.
        conv_layers = [
            nn.Conv1d(1, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        self.feature_extractor = nn.Sequential(*conv_layers)

        # Dense head: 64 channels * input_dim positions -> 128 -> num_classes logits.
        head_layers = [
            nn.Linear(64 * input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for `x`."""
        # Insert the channel axis expected by Conv1d: (batch, 1, input_dim).
        sequence = x.unsqueeze(1)
        conv_out = self.feature_extractor(sequence)
        flattened = conv_out.view(conv_out.size(0), -1)
        return self.classifier(flattened)

def get_model(input_dim, num_classes):
    """Build a fresh IDS1DCNN for `input_dim` features and `num_classes` outputs."""
    model = IDS1DCNN(input_dim, num_classes)
    return model

def evaluate_model(model, loader, device='cpu'):
    """Score `model` on every batch of `loader`.

    The model is switched to eval mode and run without gradient tracking;
    predictions are the argmax over the output's class dimension.

    Returns:
        (accuracy, f1, precision, recall), where the last three are
        weighted averages computed with zero_division=0.
    """
    model.eval()
    pred_batches, label_batches = [], []
    with torch.no_grad():
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)
            pred_batches.append(torch.argmax(logits, dim=1).cpu())
            label_batches.append(targets.cpu())

    y_pred = torch.cat(pred_batches).numpy()
    y_true = torch.cat(label_batches).numpy()

    weighted = {'average': 'weighted', 'zero_division': 0}
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, **weighted)
    precision = precision_score(y_true, y_pred, **weighted)
    recall = recall_score(y_true, y_pred, **weighted)
    return accuracy, f1, precision, recall
#
def encrypt_vector_chunks(context, vec, chunk_size=None):
    """Encrypt `vec` as a list of CKKS vectors of at most `chunk_size` values each.

    Args:
        context: TenSEAL context used for encryption.
        vec: 1-D sequence of numbers.
        chunk_size: values per ciphertext; defaults to 8192 when None.

    Returns:
        List of encrypted chunks, in the order of the input values.
    """
    size = 8192 if chunk_size is None else chunk_size
    return [
        ts.ckks_vector(context, vec[start:start + size])
        for start in range(0, len(vec), size)
    ]

def decrypt_vector_chunks(enc_chunks):
    """Decrypt every chunk and join the results into one 1-D NumPy array."""
    plain_parts = []
    for chunk in enc_chunks:
        plain_parts.append(np.array(chunk.decrypt()))
    return np.concatenate(plain_parts)

def federated_train_with_encryption(clients, input_dim, num_classes,
                                    global_test_loader,
                                    rounds=3, local_epochs=2, batch_size=64, device='cpu',
                                    lr=1e-4):
    """Run federated averaging where client weights are aggregated under CKKS encryption.

    Each round, every client starts from the current global weights, trains
    locally, and its flattened state dict is encrypted in chunks. The encrypted
    chunks are summed across clients and scaled by 1/number-of-clients (an
    unweighted mean: every client counts equally regardless of its sample
    count), then decrypted and loaded into the global model, which is evaluated
    on the global test set.

    Encryption, aggregation and decryption all use the single TenSEAL context
    created inside this function.

    Args:
        clients: non-empty sequence of (X_train, y_train, X_test, y_test) tuples.
        input_dim: number of input features of the model.
        num_classes: number of output classes of the model.
        global_test_loader: DataLoader used to evaluate the global model each round.
        rounds: number of federated rounds.
        local_epochs: local training epochs per client per round.
        batch_size: local training batch size.
        device: device used for training and evaluation.
        lr: Adam learning rate for local training (kept low for stability).

    Returns:
        The global model after the final round.

    Raises:
        RuntimeError: if TenSEAL is not available (`ts` is None).
        ValueError: if `clients` is empty.
    """
    if ts is None:
        raise RuntimeError('tenseal not installed')
    if len(clients) == 0:
        raise ValueError('federated_train_with_encryption requires at least one client')

    print(f"Starting Federated Learning with Encryption — Rounds: {rounds}, Clients: {len(clients)}")

    global_model = get_model(input_dim, num_classes).to(device)

    poly_modulus_degree = 16384
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=poly_modulus_degree,
        coeff_mod_bit_sizes=[60, 40, 40, 60]
    )
    context.global_scale = 2**40
    context.generate_galois_keys()
    # A CKKS ciphertext packs poly_modulus_degree / 2 values, so the chunk size is
    # derived from the context instead of relying on a matching hard-coded default.
    slots_per_ciphertext = poly_modulus_degree // 2

    for r in range(rounds):
        print(f"\nRound {r+1}/{rounds} — clients: {len(clients)}")
        encrypted_updates = []

        # LOCAL TRAINING & LOCAL EVALUATION
        for cid, (Xc_train, yc_train, Xc_test, yc_test) in enumerate(clients):
            print(f" Client {cid}: training on {len(Xc_train)} samples")

            # Every client starts the round from the current global weights.
            local_model = get_model(input_dim, num_classes).to(device)
            local_model.load_state_dict(global_model.state_dict())
            train_local(local_model, Xc_train, yc_train,
                        epochs=local_epochs, batch_size=batch_size, lr=lr, device=device)

            if len(Xc_test) <= 10:
                print(f"  Client {cid}: Local test set too small or unbalanced for meaningful metrics.")
            else:
                local_test_dataset = TensorDataset(torch.tensor(Xc_test, dtype=torch.float32),
                                                   torch.tensor(yc_test, dtype=torch.long))
                local_test_loader = DataLoader(local_test_dataset, batch_size=256, shuffle=False)
                loc_acc, loc_f1, _, _ = evaluate_model(local_model, local_test_loader, device)
                print(f"  Client {cid} Local Metrics -> Acc: {loc_acc:.4f}, F1: {loc_f1:.4f}")

            # Flatten the full state dict (not just parameters()) so that the layout
            # matches the state-dict walk used to rebuild the global model below.
            weights_vector, _ = model_to_vector(local_model.state_dict())
            enc_chunks = encrypt_vector_chunks(context, weights_vector, chunk_size=slots_per_ciphertext)
            encrypted_updates.append(enc_chunks)

        # ENCRYPTED AGGREGATION: chunk-wise sum over clients, then scale to the mean.
        num_chunks = len(encrypted_updates[0])
        averaging_factor = 1.0 / len(encrypted_updates)
        aggregated_chunks = []
        for chunk_idx in range(num_chunks):
            # Note: += accumulates into the first client's ciphertext, which is not reused.
            chunk_sum = encrypted_updates[0][chunk_idx]
            for client_chunks in encrypted_updates[1:]:
                chunk_sum += client_chunks[chunk_idx]
            aggregated_chunks.append(chunk_sum * averaging_factor)

        decrypted_avg = decrypt_vector_chunks(aggregated_chunks)
        print(" Decrypted aggregated vector sample (first 10 elements):", np.round(decrypted_avg[:10], 6))

        # Rebuild a state dict with the global model's key order, shapes and dtypes.
        template_state = global_model.state_dict()
        template_shapes = {name: tensor.shape for name, tensor in template_state.items()}
        new_state = vector_to_model(template_state, decrypted_avg, template_shapes)
        global_model.load_state_dict(new_state)
        print(f" Round {r+1} global model updated.")

        print(f" Round {r+1}: Evaluating new global model on GLOBAL test set...")
        glob_acc, glob_f1, glob_prec, glob_rec = evaluate_model(global_model, global_test_loader, device)
        print(f" Round {r+1} Global Metrics -> Acc: {glob_acc:.4f}, F1: {glob_f1:.4f}, Prec: {glob_prec:.4f}, Rec: {glob_rec:.4f}")

    return global_model

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import torch

# Launch federated training once the client splits from the partitioning cell exist.
if 'clients' in globals() and len(clients) >= 1:

    input_dim = clients[0][0].shape[1]
    # Size the output layer from the largest label id rather than the number of
    # distinct ids: the two only agree when the ids are contiguous from 0, and a
    # label id >= num_classes would make the loss computation fail.
    num_classes = int(np.max(y)) + 1
    try:
        global_model = federated_train_with_encryption(
            clients,
            input_dim,
            num_classes=num_classes,
            global_test_loader=global_test_loader, # <-- Pass the global test loader
            rounds=3,
            local_epochs=2,
            device='cpu'
        )
    except Exception as e:
        print('An error occurred during training:', e)
        # Bare raise re-raises the active exception with its original traceback.
        raise

    print("\n\n=== FEDERATED TRAINING COMPLETE ===")


else:
    print('Clients not prepared. Ensure dataset was loaded and clients created in Cell 5.')

Starting Federated Learning with Encryption — Rounds: 3, Clients: 10

Round 1/3 — clients: 10
 Client 0: training on 151644 samples
  Client 0: Local test set too small or unbalanced for meaningful metrics.
 Client 1: training on 23907 samples
  Client 1: Local test set too small or unbalanced for meaningful metrics.
 Client 2: training on 18506 samples
  Client 2 Local Metrics -> Acc: 0.9879, F1: 0.9877
 Client 3: training on 171 samples
  Client 3 Local Metrics -> Acc: 1.0000, F1: 1.0000
 Client 4: training on 196743 samples
  Client 4 Local Metrics -> Acc: 1.0000, F1: 0.9999
 Client 5: training on 121 samples
  Client 5 Local Metrics -> Acc: 1.0000, F1: 1.0000
 Client 6: training on 1252 samples
  Client 6 Local Metrics -> Acc: 0.8854, F1: 0.8822
 Client 7: training on 34091 samples
  Client 7 Local Metrics -> Acc: 0.9885, F1: 0.9884
 Client 8: training on 279 samples
  Client 8 Local Metrics -> Acc: 0.7857, F1: 0.6914
 Client 9: training on 13066 samples
  Client 9: Local test set 