In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import requests
from io import StringIO
import matplotlib.pyplot as plt
from collections import OrderedDict
import copy

In [2]:
torch.manual_seed(42)
np.random.seed(42)

In [3]:
NUM_CLIENTS = 5
BYZANTINE_CLIENTS = 1
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 5
GLOBAL_ROUNDS = 5

In [4]:
def load_nsl_kdd():
    url = "https://raw.githubusercontent.com/defcom17/NSL_KDD/master/KDDTrain%2B.csv"
    try:
        response = requests.get(url)
        response.raise_for_status()
        
        columns = [
            'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
            'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in',
            'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
            'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login',
            'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',
            'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
            'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
            'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate',
            'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 'attack_type', 'difficulty_level'
        ]
        
        df = pd.read_csv(StringIO(response.text), header = None, names = columns)
        return df
    except Exception as e:
        print(f"Error loading dataset: {e}")
        try:
            df = pd.read_csv("KDDTrain+.csv", header = None, names = columns)
            return df
        except:
            return None

In [14]:
def preprocess_data(df):
    df = df.drop('difficulty_level', axis = 1)

    categorical_cols = ['protocol_type', 'service', 'flag']
    
    df_encoded = pd.get_dummies(df, columns = categorical_cols)
    
    attack_mapping = {
        'normal': 'normal',
        'back': 'dos', 'land': 'dos', 'neptune': 'dos', 'pod': 'dos', 
        'smurf': 'dos', 'teardrop': 'dos',
        'buffer_overflow': 'u2r', 'loadmodule': 'u2r', 'perl': 'u2r', 
        'rootkit': 'u2r',
        'ftp_write': 'r2l', 'guess_passwd': 'r2l', 'imap': 'r2l', 
        'multihop': 'r2l', 'phf': 'r2l', 'spy': 'r2l', 'warezclient': 'r2l', 
        'warezmaster': 'r2l',
        'ipsweep': 'probe', 'nmap': 'probe', 'portsweep': 'probe', 
        'satan': 'probe'
    }

    df_encoded['attack_category'] = df['attack_type'].map(lambda x: attack_mapping.get(x, 'other'))
    
    label_encoder = LabelEncoder()
    df_encoded['attack_category'] = label_encoder.fit_transform(df_encoded['attack_category'])
    
    attack_category_mapping = {i: category for i, category in enumerate(label_encoder.classes_)}
    
    df_encoded = df_encoded.drop('attack_type', axis=1)
    
    X = df_encoded.drop('attack_category', axis=1)
    y = df_encoded['attack_category']
    
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    
    X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size = 0.2, random_state = 42, stratify = y)
    
    print(f"Number of features: {X_train.shape[1]}")
    print(f"Number of classes: {len(np.unique(y_train))}")
    
    return X_train, X_test, y_train, y_test, attack_category_mapping

class NSLKDDDataset(Dataset):
    def __init__(self, features, labels):
        self.features = torch.FloatTensor(features)
        self.labels = torch.LongTensor(labels.values)
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

In [6]:
def split_data_for_clients(X, y, num_clients):
    client_data_size = len(X) // num_clients
    
    client_data = []
    
    # splitting data amongst clients
    for i in range(num_clients):
        start_idx = i * client_data_size
        end_idx = (i + 1) * client_data_size if i < num_clients - 1 else len(X)
        
        client_X = X[start_idx: end_idx]
        client_y = y.iloc[start_idx: end_idx]
        
        client_data.append((client_X, client_y))
    
    return client_data

In [7]:
class IDSModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(IDSModel, self).__init__()
        
        self.network = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )
    
    def forward(self, x):
        return self.network(x)

In [8]:
def apply_byzantine_attack(model, attack_type = 'sign_flip'): # label/sign flipping by default
    with torch.no_grad():
        attacked_state_dict = OrderedDict()
        
        for key, param in model.state_dict().items():
            if 'weight' in key or 'bias' in key:
                if attack_type == 'sign_flip':
                    # label/sign flipping
                    attacked_state_dict[key] = -param.clone()
                elif attack_type == 'random':
                    # random value replacement
                    attacked_state_dict[key] = torch.randn_like(param)
                elif attack_type == 'constant':
                    #  large constant replacement
                    attacked_state_dict[key] = torch.ones_like(param) * 10
            else:
                attacked_state_dict[key] = param.clone()
        
        model_copy = copy.deepcopy(model)
        model_copy.load_state_dict(attacked_state_dict)
        
    return model_copy

In [9]:
def krum_aggregation(client_models, num_byzantine, global_model):
    num_clients = len(client_models)
    num_to_select = num_clients - num_byzantine - 2
    
    if num_to_select <= 0:
        num_to_select = 1
    
    client_params = []
    for model in client_models:
        params = []
        for param in model.parameters():
            params.append(param.data.view(-1))
        client_params.append(torch.cat(params))
    
    distances = torch.zeros(num_clients, num_clients)
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            distances[i, j] = torch.norm(client_params[i] - client_params[j]) ** 2
            distances[j, i] = distances[i, j]

    scores = torch.zeros(num_clients)
    for i in range(num_clients):
        closest_distances = torch.topk(distances[i], num_clients - num_to_select, largest = False).values
        scores[i] = torch.sum(closest_distances)
    
    selected_client = torch.argmin(scores).item()
    
    global_model.load_state_dict(client_models[selected_client].state_dict())
    
    return global_model

In [10]:
def train_client_model(model, train_loader, epochs, learning_rate, client_id, is_byzantine = False):
    client_model = copy.deepcopy(model)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(client_model.parameters(), lr = learning_rate)
    
    client_model.train()
    
    # simply returns model if it is Byzantine
    if is_byzantine:
        return apply_byzantine_attack(client_model)
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            optimizer.zero_grad()

            # forward pass
            outputs = client_model(inputs)
            loss = criterion(outputs, labels)

            # backward pass + optimize
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        
        print(f"Client {client_id} - Epoch {epoch + 1} / {epochs} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.2f}%")
    
    return client_model

In [11]:
def eval_model(model, test_loader):
    model.eval()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy: .3f}%")
    
    return accuracy

In [12]:
def run_federated_learning():
    df = load_nsl_kdd()
    if df is None:
        print("Error: Failed to load dataset.")
        return
    
    X_train, X_test, y_train, y_test, attack_mapping = preprocess_data(df)
    
    client_data = split_data_for_clients(X_train, y_train, NUM_CLIENTS)

    test_dataset = NSLKDDDataset(X_test, pd.Series(y_test))
    test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)
    
    client_loaders = []
    for X_client, y_client in client_data:
        client_dataset = NSLKDDDataset(X_client, y_client)
        client_loader = DataLoader(client_dataset, batch_size=BATCH_SIZE, shuffle=True)
        client_loaders.append(client_loader)
    
    num_features = X_train.shape[1]
    num_classes = len(attack_mapping)
    global_model = IDSModel(num_features, num_classes)
    
    round_accuracies = []
    
    # federated learning rounds
    for round_num in range(GLOBAL_ROUNDS):
        print(f"\nRound {round_num + 1} / {GLOBAL_ROUNDS}")
        
        client_models = []
        
        for client_id in range(NUM_CLIENTS):
            print(f"\nTraining Client {client_id + 1} / {NUM_CLIENTS}")

            is_byzantine = client_id < BYZANTINE_CLIENTS
            
            client_model = train_client_model(
                global_model, 
                client_loaders[client_id],
                EPOCHS,
                LEARNING_RATE,
                client_id,
                is_byzantine
            )
            
            client_models.append(client_model)
            
        global_model = krum_aggregation(client_models, BYZANTINE_CLIENTS, global_model)
        
        accuracy = eval_model(global_model, test_loader)
        round_accuracies.append(accuracy)

In [15]:
if __name__ == "__main__":
    run_federated_learning()

Number of features: 122
Number of classes: 5

Round 1 / 5 -----

Training Client 1 / 5

Training Client 2 / 5
Client 1 - Epoch 1 / 5 - Loss: 0.2343 - Accuracy: 93.80%
Client 1 - Epoch 2 / 5 - Loss: 0.0742 - Accuracy: 97.60%
Client 1 - Epoch 3 / 5 - Loss: 0.0488 - Accuracy: 98.36%
Client 1 - Epoch 4 / 5 - Loss: 0.0378 - Accuracy: 98.80%
Client 1 - Epoch 5 / 5 - Loss: 0.0342 - Accuracy: 98.93%

Training Client 3 / 5
Client 2 - Epoch 1 / 5 - Loss: 0.2399 - Accuracy: 93.29%
Client 2 - Epoch 2 / 5 - Loss: 0.0692 - Accuracy: 97.66%
Client 2 - Epoch 3 / 5 - Loss: 0.0485 - Accuracy: 98.37%
Client 2 - Epoch 4 / 5 - Loss: 0.0369 - Accuracy: 98.74%
Client 2 - Epoch 5 / 5 - Loss: 0.0355 - Accuracy: 98.84%

Training Client 4 / 5
Client 3 - Epoch 1 / 5 - Loss: 0.2299 - Accuracy: 93.85%
Client 3 - Epoch 2 / 5 - Loss: 0.0670 - Accuracy: 97.74%
Client 3 - Epoch 3 / 5 - Loss: 0.0434 - Accuracy: 98.57%
Client 3 - Epoch 4 / 5 - Loss: 0.0360 - Accuracy: 98.91%
Client 3 - Epoch 5 / 5 - Loss: 0.0327 - Accura