## **Federated Learning for attack detection: 5 nodes sharing weights** 
Includes the cases of the following aggregation functions: 
* Average
* Median
* Krum
* Geometric Median

Saving local and global models. Trained with UNSW-NB15 dataset partitions or Bot-IoT dataset partitions. 

### Static elements for all experiments (execute first)

Preprocessing extracted from https://github.com/polvalls9/Transfer-Learning-Based-Intrusion-Detection-in-5G-and-IoT-Networks.git

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
import pandas as pd

from tensorflow.keras import datasets, layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Disable warns
pd.options.mode.chained_assignment = None  # default='warn'

In [None]:
def _last_field_as_hex_int(address):
    """Return the last '.'- or ':'-separated field of ``address`` parsed as base 16."""
    return int(address.split(".")[-1].split(":")[-1], 16)


def _port_as_hex_int(port):
    """Parse a port value as base 16; purely alphabetic values map to -1."""
    text = str(port)
    if "0x" in text:
        text = text.replace('0x', '')
    # Ports read as floats look like "80.0": drop the trailing ".0".
    if text[-2:] == '.0':
        text = text[:-2]
    return -1 if text.isalpha() else int(text, 16)


def preprocessing(data):
    """Convert raw flow records into standardised model inputs and labels.

    Rows are filtered to a fixed set of ``proto`` and ``state`` values, the
    address and port fields are reduced to integers, some volume fields are
    log-scaled, ``proto``/``state`` are one-hot encoded, and every feature is
    standardised with a ``StandardScaler`` fitted on ``data`` itself.

    The one-hot columns are always emitted in the same order, whether or not
    a given category occurs in ``data``, so that every partition produces the
    same feature layout.

    Args:
        data: DataFrame with at least the columns ``label``, ``proto``,
            ``srcip``, ``sport``, ``dstip``, ``dsport``, ``spkts``, ``dpkts``,
            ``sbytes``, ``dbytes``, ``state``, ``stime``, ``ltime``, ``dur``.

    Returns:
        x: Array of shape (rows, features, 1, 1) with the standardised features.
        y: Array built from the ``label`` column and its 0/1 complement
            ``normal`` (two columns, assuming ``label`` is numeric).
    """
    # Keep only the protocols and connection states of interest.
    # NOTE(review): 'ipv6-icmp' and 'rarp' pass this filter but have no
    # guaranteed one-hot column below; if they occur, the feature count
    # grows beyond the 13 guaranteed dummy columns. Confirm this is intended.
    kept_protos = ['tcp', 'udp', 'icmp', 'arp', 'ipv6-icmp', 'igmp', 'rarp']
    kept_states = ['RST', 'REQ', 'INT', 'FIN', 'CON', 'ECO', 'ACC', 'PAR']
    data = data.loc[data['proto'].isin(kept_protos), :]
    data = data.loc[data['state'].isin(kept_states), :]

    # Work on explicit copies so the caller's frame is never written to.
    data_labels = data[['label']].copy()
    data_features = data[['proto', 'srcip', 'sport', 'dstip', 'dsport', 'spkts', 'dpkts',
                          'sbytes', 'dbytes', 'state', 'stime', 'ltime', 'dur']].copy()

    # Addresses: keep only the last field and read it as a base-16 integer.
    for col in ('srcip', 'dstip'):
        data_features[col] = data_features[col].apply(_last_field_as_hex_int).astype(int)

    # Ports: strip "0x" / ".0" decorations and read the rest as base 16.
    for col in ('sport', 'dsport'):
        data_features[col] = data_features[col].apply(_port_as_hex_int).astype(int)

    # Compress heavy-tailed volume fields.
    for col in ('dur', 'sbytes', 'dbytes', 'spkts'):
        data_features[col] = np.log1p(data_features[col])

    # Labels: add 'normal' as the 0/1 complement of 'label' (swap via a temporary 2).
    normal = data_labels['label'].replace(1, 2).replace(0, 1).replace(2, 0)
    data_labels.insert(1, 'normal', normal)
    data_labels = pd.get_dummies(data_labels)

    # Features: one-hot encode the categorical columns.
    data_features = pd.get_dummies(data_features)

    # A partition may lack some categories; add their columns filled with 0.
    proto_cols = ['proto_arp', 'proto_icmp', 'proto_igmp', 'proto_tcp', 'proto_udp']
    state_cols = ['state_ACC', 'state_CON', 'state_ECO', 'state_FIN',
                  'state_INT', 'state_PAR', 'state_REQ', 'state_RST']
    guaranteed = proto_cols + state_cols
    for col in guaranteed:
        if col not in data_features.columns:
            data_features[col] = 0

    # Fix the column order. It matches what get_dummies yields when every
    # guaranteed category is present, so layouts agree across partitions.
    leading = [col for col in data_features.columns if col not in guaranteed]
    data_features = data_features[leading + guaranteed]

    # Standardise every feature using statistics of this subset only.
    data_features = StandardScaler().fit_transform(data_features)

    # Add two trailing singleton axes: (rows, features) -> (rows, features, 1, 1).
    data_features = np.expand_dims(data_features, axis=2)
    data_features = np.expand_dims(data_features, axis=3)

    x = data_features
    y = data_labels.to_numpy()

    return x, y

Model definition: 

In [None]:
def build_model():
    """Build the uncompiled convolutional classifier shared by all nodes.

    The network takes inputs of shape (24, 1, 1) and ends in a 2-unit
    softmax layer.

    Returns:
        model: A freshly initialised ``Sequential`` model.
    """
    input_shape = (24, 1, 1)

    # NOTE(review): the second Conv2D also receives input_shape, as in the
    # original definition; it is presumably redundant there -- confirm.
    stack = [
        layers.Conv2D(filters=32, input_shape=input_shape, kernel_size=(1, 10), activation='relu', padding='same'),
        layers.MaxPooling2D(pool_size=(1, 2), padding='same'),
        layers.Conv2D(filters=64, input_shape=input_shape, kernel_size=(1, 10), activation='relu', padding='same'),
        layers.MaxPooling2D(pool_size=(1, 2), padding='same'),
        layers.Flatten(),
        Dense(444, activation='relu'),
        Dense(2, activation='softmax'),
    ]

    model = models.Sequential()
    for layer in stack:
        model.add(layer)

    return model

In [None]:
# Number of federated nodes; the training loop iterates over range(num_nodes).
num_nodes = 5 # Number of nodes defined for all cases

Model training extracted and adapted from https://github.com/polvalls9/Transfer-Learning-Based-Intrusion-Detection-in-5G-and-IoT-Networks.git 

## Aggregation function: AVERAGE

In [None]:
def aggregate(w_list):
    """
    Aggregation function: layer-wise average of the local model weights.

    The weights of one model are a list of arrays with different shapes, so
    the nodes cannot be stacked into one regular array. The mean is therefore
    taken separately for each layer position.

    Args:
    w_list: Contains model weights of each local model (one list of arrays
        per node, all with the same layer structure)

    Returns:
    avg_w: Only one set of weights (list of arrays, one per layer) which
        consists of the weight average.

    Raises:
    ValueError: If w_list is empty.
    """
    if len(w_list) == 0:
        raise ValueError("w_list must contain the weights of at least one model")

    # zip(*w_list) groups the arrays of all nodes that belong to the same layer.
    avg_w = [np.mean(layer_weights, axis=0) for layer_weights in zip(*w_list)]
    return avg_w

## Aggregation function: GEOMETRIC MEDIAN

In [None]:
def weiszfeld_update(points, tol=1e-5, max_iter=100):
    """
    Approximate the geometric median of a point set with Weiszfeld iterations.

    Starting from the arithmetic mean, each iteration re-weights every point
    by the inverse of its Euclidean distance to the current estimate and
    moves the estimate to the resulting weighted average.

    Args:
        points: A list of points (numpy arrays), one point per row
        tol: Stop once the estimate moves by less than this norm
        max_iter: Maximum number of iterations

    Returns:
        median: The last estimate computed (the mean if max_iter is 0)
    """
    pts = np.array(points)
    estimate = np.mean(pts, axis=0)
    smallest = np.finfo(float).eps  # stand-in for exact zero distances

    iteration = 0
    while iteration < max_iter:
        iteration += 1

        # Distance of every point to the current estimate.
        dist = np.linalg.norm(pts - estimate, axis=1)
        safe_dist = np.where(dist == 0, smallest, dist)
        inv_dist = 1 / safe_dist

        # Inverse-distance weighted average of the points.
        candidate = np.sum(pts * inv_dist[:, None], axis=0) / np.sum(inv_dist)

        moved = np.linalg.norm(candidate - estimate)
        estimate = candidate
        if moved < tol:
            break

    return estimate

In [None]:
def aggregate(w_list):
    """
    Aggregates list of local weights using geometric median.

    Each model's weights are flattened into one vector, the geometric median
    of those vectors is computed with ``weiszfeld_update``, and the result is
    cut back into arrays shaped like the layers of the first model.

    Args:
        w_list: Contains list of local model weights (one list of numpy
            arrays per node, all with the same layer structure)

    Returns:
        geomed_weights: Aggregated weights, one array per layer

    Raises:
        ValueError: If w_list is empty.
    """
    if len(w_list) == 0:
        raise ValueError("w_list must contain the weights of at least one model")

    # Flatten weights to 1D vector to compute geometric median
    flat_weights = [np.concatenate([w.flatten() for w in weights], axis=0) for weights in w_list]

    # Compute geometric median using Weiszfeld's algorithm
    flat_median = weiszfeld_update(flat_weights)

    # Reshape the flat median back to the original shape
    geomed_weights = []
    index = 0
    for weight in w_list[0]:
        # weight.size is always an int; np.prod(()) would give the float 1.0
        # for a scalar weight, which cannot be used as a slice bound.
        size = weight.size
        geomed_weights.append(flat_median[index:index + size].reshape(weight.shape))
        index += size

    return geomed_weights


## Aggregation function: MEDIAN

In [None]:
def aggregate(w_list):
    """
    Aggregation function: element-wise median of the local model weights.

    Args:
        w_list: Contains model weights of each local model

    Returns:
        med_w: List with one array per layer holding the median across nodes
    """
    med_w = []
    # Walk the layers in lockstep across all nodes.
    for same_layer_all_nodes in zip(*w_list):
        layer_median = np.median(same_layer_all_nodes, axis=0)
        med_w.append(layer_median)
    return med_w


## Aggregation function: KRUM

In [None]:

def aggregate(w_list, num_mal=0):
    """
    Krum aggregation function. Method designed to robustly aggregate weights, avoiding the
    influence of malicious nodes. It selects the set of weights closest to the majority of
    other nodes, using pairwise squared Euclidean distances.

    The Krum score of a node is the sum of the squared distances to its
    ``n - num_mal - 2`` nearest *other* nodes; a node's zero distance to
    itself is not counted as a neighbour.

    Args:
        w_list: Contains model weights of each local model
        num_mal: Number of malicious nodes. Default: 0

    Returns:
        kr_weights: Selected weights after Krum aggregation mechanism (the unmodified
            entry of w_list with the lowest score). Used for global model

    Raises:
        ValueError: If w_list is empty.
    """
    if len(w_list) == 0:
        raise ValueError("w_list must contain the weights of at least one model")

    total_nodes = len(w_list)
    # Number of nearest neighbours entering each score; never negative.
    num_to_consider = max(total_nodes - num_mal - 2, 0)

    # Flatten weights to 1D array to compute distances
    flat_weights = [np.concatenate([w.flatten() for w in weights], axis=0) for weights in w_list]

    # Symmetric matrix of pairwise squared Euclidean distances (diagonal stays 0)
    distances = np.zeros((total_nodes, total_nodes))
    for i in range(total_nodes):
        for j in range(i + 1, total_nodes):
            dist = np.sum((flat_weights[i] - flat_weights[j]) ** 2)
            distances[i, j] = dist
            distances[j, i] = dist

    # Krum score for each node
    krum_scores = []
    for i in range(total_nodes):
        sorted_distances = np.sort(distances[i])
        # Position 0 is the node's own distance (0, the smallest possible), so skip it.
        score = np.sum(sorted_distances[1:num_to_consider + 1])
        krum_scores.append(score)

    krum_index = np.argmin(krum_scores)  # first node with the minimum score
    kr_weights = w_list[krum_index]

    return kr_weights


## TRAINING MODEL
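Execute only the cell of the aggregation function of interest: every variant above is named `aggregate`, so the last one executed is the one used during training.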

Training datasets are defined based on the dataset partition that we want to train on: 

In [None]:
# Load the five local partitions; they are later split and used for nodes 0-4 in this order.
# NOTE(review): the 'x' in the file names looks like a placeholder for the
# dataset/partition scheme -- adjust the paths to the files actually used.
training1 = pd.read_csv('../x_part1.csv')
training2 = pd.read_csv('../x_part2.csv')
training3 = pd.read_csv('../x_part3.csv')
training4 = pd.read_csv('../x_part4.csv')
training5 = pd.read_csv('../x_part5.csv')

In [None]:
# Number of federated rounds: each round trains every node, then aggregates.
global_updates = 5 # Change according to how many iterations wanted
# Epochs each node trains locally per round (passed to model.fit).
local_epochs = 1 # Change according to how many epochs wanted 
# Loss and metrics used when compiling both the local and the global models.
loss_fct = "categorical_crossentropy"
metrics = ['accuracy']

In [None]:
def train_local_model(model, node, x_train, y_train, xval, yval):
    """Compile and fit one node's model, checkpointing its best epoch to disk.

    The model is recompiled with a fresh Adam optimizer on every call and
    trained for ``local_epochs`` epochs using the module-level ``loss_fct``
    and ``metrics``.

    Args:
        model: Keras model of the node; trained in place.
        node: Node index, used only to build the checkpoint file name.
        x_train, y_train: Training features and labels.
        xval, yval: Validation features and labels.

    Returns:
        The trained model followed by the per-epoch lists of training loss,
        training accuracy, validation loss and validation accuracy.
    """
    filepath = '../models/node'+str(node)+'w_x_5_1.keras'

    # Stop when val_loss has not improved for 2 consecutive epochs.
    stop_early = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
    # Overwrite the checkpoint only when val_loss improves on the best so far.
    keep_best = keras.callbacks.ModelCheckpoint(
        filepath=filepath,
        monitor='val_loss',
        save_best_only=True,
    )
    callbacks = [stop_early, keep_best]

    optimizer = keras.optimizers.Adam(learning_rate=5e-4)
    model.compile(optimizer=optimizer, loss=loss_fct, metrics=metrics)

    history = model.fit(
        x_train,
        y_train,
        epochs=local_epochs,
        validation_data=(xval, yval),
        callbacks=callbacks,
        batch_size=2048,
    )

    curves = history.history
    return model, curves['loss'], curves['accuracy'], curves['val_loss'], curves['val_accuracy']

First initialize the global model and for each local dataset split into training and validation subsets:

In [None]:
# Global model: its weights are copied to every node at the start of each
# round and replaced by the aggregated local weights at the end of it.
global_model = build_model()

In [None]:
# 80/20 shuffled train/validation split of each node's partition; the fixed
# random_state makes the split reproducible.
train1, val1 = train_test_split(training1, test_size=0.2, shuffle = True, random_state=42)
train2, val2 = train_test_split(training2, test_size=0.2, shuffle = True, random_state=42)
train3, val3 = train_test_split(training3, test_size=0.2, shuffle = True, random_state=42)
train4, val4 = train_test_split(training4, test_size=0.2, shuffle = True, random_state=42)
train5, val5 = train_test_split(training5, test_size=0.2, shuffle = True, random_state=42)

Preprocess all subset samples: 

In [None]:
# Training subsets (x*, y*) and validation subsets (xv*, yv*) for nodes 1-5.
# NOTE(review): preprocessing() fits its own StandardScaler on whatever frame
# it receives, so each validation subset is standardised with its own
# statistics rather than with those of the matching training subset.
x1, y1 = preprocessing(train1)
x2, y2 = preprocessing(train2)
x3, y3 = preprocessing(train3)
x4, y4 = preprocessing(train4)
x5, y5 = preprocessing(train5)
xv1, yv1 = preprocessing(val1)
xv2, yv2 = preprocessing(val2)
xv3, yv3 = preprocessing(val3)
xv4, yv4 = preprocessing(val4)
xv5, yv5 = preprocessing(val5)

In [None]:
# Optimizer used to compile the global model; train_local_model creates its
# own Adam instance for each local training call.
optimizer = keras.optimizers.Adam(learning_rate=5e-4)

Initialization of local models: 

In [None]:
# One local model per node (cp0 -> node 0, ..., cp4 -> node 4), all sharing
# the architecture returned by build_model().
cp0 = build_model()
cp1 = build_model()
cp2 = build_model()
cp3 = build_model()
cp4 = build_model()

In [None]:
global_model.compile(optimizer=optimizer, loss=loss_fct, metrics=metrics)

# Per-node lookup tables, indexed by node number.
cp = [cp0, cp1, cp2, cp3, cp4]
train_sets = [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5)]
val_sets = [(xv1, yv1), (xv2, yv2), (xv3, yv3), (xv4, yv4), (xv5, yv5)]

for i in range(global_updates):
    w_list = []  # local weights collected during this round
    for node in range(num_nodes):
        x, y = train_sets[node]
        xv, yv = val_sets[node]

        # Every node starts the round from the current global weights.
        cp[node].set_weights(global_model.get_weights())

        local_model, local_loss, local_acc, local_val_loss, local_val_acc = train_local_model(cp[node], node, x, y, xv, yv)
        w_list.append(local_model.get_weights())

    # Combine the local weights with whichever aggregate() was executed
    # (the name avg_w is kept even when the aggregation is not an average).
    avg_w = aggregate(w_list)
    global_model.set_weights(avg_w)


# Save the final global model.
# NOTE(review): newer Keras releases may only accept '.keras' or '.h5' file
# extensions here -- confirm against the installed version.
global_model.save('../models/w_x_5_1.hdf5')