In [2]:
# Import required libraries
import pandas as pd
import numpy as np
import pickle
from keras.models import load_model
from sklearn.preprocessing import StandardScaler
from scapy.all import sniff, IP, TCP, UDP
import queue
import threading
import time
import warnings

# Suppress unwanted warnings
warnings.filterwarnings("ignore")

# Load trained model and encoders
model = load_model("IDS_Detect.h5")
with open("label_encoders.pkl", "rb") as f:
    label_encoders = pickle.load(f)

# Initialize global variables
packet_queue = queue.Queue()
packet_batch = []
batch_size = 100
ip_duration_tracker = {}
service_src_connections = {}
service_dst_connections = {}
ip_packet_tracker = {}
source_avg_tracker = {}
dest_avg_tracker = {}
dest_connections = {}

# Port-to-service mapping
port_to_service = {
    80: 'http', 443: 'ssl', 21: 'ftp', 20: 'ftp-data', 53: 'dns', 67: 'dhcp',
    68: 'dhcp', 110: 'pop3', 1812: 'radius', 25: 'smtp', 161: 'snmp',
    22: 'ssh', 194: 'irc'
}

# Function to safely transform unseen labels in categorical features
def safe_transform_label(encoder, column):
    unknown_val = encoder.transform(["Unknown"])[0] if "Unknown" in encoder.classes_ else 0
    transformed = [
        encoder.transform([x])[0] if x in encoder.classes_ else unknown_val for x in column
    ]
    return transformed

# Function to preprocess packet data
def preprocess_packet_data(packet_df, encoders):
    packet_df = packet_df.copy()
    # Apply label encoding for categorical columns
    for col in ["proto", "service", "state"]:
        if col in encoders and col in packet_df.columns:
            packet_df[col] = safe_transform_label(encoders[col], packet_df[col].astype(str))

    # Fill missing values to handle any NaNs
    packet_df.fillna(0, inplace=True)

    # Standardize numerical features within the batch
    numerical_cols = [
        "dur", "sbytes", "dbytes", "sttl", "dttl", "sload", "dload",
        "spkts", "dpkts", "ct_srv_src", "ct_srv_dst", "smean", "dmean", "ct_dst_ltm"
    ]
    scaler = StandardScaler()
    existing_numerical_cols = [col for col in numerical_cols if col in packet_df.columns]
    if existing_numerical_cols:
        packet_df[existing_numerical_cols] = scaler.fit_transform(packet_df[existing_numerical_cols])

    return packet_df

# Packet callback function
def packet_callback(packet):
    global packet_batch
    try:
        src_ip = packet[IP].src if IP in packet else "Unknown"
        dst_ip = packet[IP].dst if IP in packet else "Unknown"

        service = "-"
        if TCP in packet or UDP in packet:
            src_port = packet.sport if TCP in packet else packet[UDP].sport
            dst_port = packet.dport if TCP in packet else packet[UDP].dport
            service = port_to_service.get(src_port, port_to_service.get(dst_port, "-"))

        # Fix for sttl and dttl to use separate values
        sttl = packet[IP].ttl if IP in packet else 64  # Source TTL
        dttl = packet[IP].ttl if IP in packet else 64  # Destination TTL (if applicable)
        
        # Modify logic to consider both source and destination IPs separately
        if IP in packet:
            sttl = packet[IP].ttl  # TTL value for source IP
            if "dst" in packet[IP].fields:
                dttl = packet[IP].ttl  # TTL value seen by the destination device


        state = ""
        if TCP in packet:
            flags = packet[TCP].flags
            if flags & 0x02:  # SYN
                state = "REQ"
            elif flags & 0x12:  # SYN-ACK
                state = "ACC"
            elif flags & 0x01:  # FIN
                state = "FIN"
            elif flags & 0x04:  # RST
                state = "RST"
            elif flags & 0x10:  # ACK
                state = "CON"

        service_src_key = (service, src_ip)
        service_dst_key = (service, dst_ip)
        pair_key = (src_ip, dst_ip)
        reverse_pair_key = (dst_ip, src_ip)
        current_time = time.time()

        service_src_connections[service_src_key] = service_src_connections.get(service_src_key, 0) + 1
        service_dst_connections[service_dst_key] = service_dst_connections.get(service_dst_key, 0) + 1
        dest_connections[dst_ip] = dest_connections.get(dst_ip, 0) + 1

        ip_duration_tracker.setdefault(src_ip, {"total_source_bytes": 0, "start_time": current_time})
        ip_duration_tracker.setdefault(dst_ip, {"total_dest_bytes": 0, "start_time": current_time})

        ip_packet_tracker.setdefault(pair_key, {"spkts": 0, "dpkts": 0})
        ip_packet_tracker.setdefault(reverse_pair_key, {"spkts": 0, "dpkts": 0})

        ip_packet_tracker[pair_key]["spkts"] += 1
        ip_packet_tracker[reverse_pair_key]["dpkts"] += 1

        src_elapsed_time = current_time - ip_duration_tracker[src_ip]["start_time"]
        dst_elapsed_time = current_time - ip_duration_tracker[dst_ip]["start_time"]

        src_bytes = len(packet)
        dst_bytes = len(packet)
        ip_duration_tracker[src_ip]["total_source_bytes"] += src_bytes
        ip_duration_tracker[dst_ip]["total_dest_bytes"] += dst_bytes

        sload = (ip_duration_tracker[src_ip]["total_source_bytes"] / src_elapsed_time) * 8 if src_elapsed_time > 0 else 0
        dload = (ip_duration_tracker[dst_ip]["total_dest_bytes"] / dst_elapsed_time) * 8 if dst_elapsed_time > 0 else 0

        source_avg_tracker.setdefault(src_ip, {"total_size": 0, "packet_count": 0})
        dest_avg_tracker.setdefault(dst_ip, {"total_size": 0, "packet_count": 0})

        source_avg_tracker[src_ip]["total_size"] += src_bytes
        source_avg_tracker[src_ip]["packet_count"] += 1
        dest_avg_tracker[dst_ip]["total_size"] += dst_bytes
        dest_avg_tracker[dst_ip]["packet_count"] += 1

        smean = source_avg_tracker[src_ip]["total_size"] / source_avg_tracker[src_ip]["packet_count"]
        dmean = dest_avg_tracker[dst_ip]["total_size"] / dest_avg_tracker[dst_ip]["packet_count"]

        packet_info = {
            "dur": src_elapsed_time,
            "proto": packet[0][1].proto if hasattr(packet[0][1], "proto") else "Unknown",
            "service": service,
            "state": state,
            "sbytes": src_bytes,
            "dbytes": dst_bytes,
            "sttl": sttl,
            "dttl": dttl,
            "sload": sload,
            "dload": dload,
            "spkts": ip_packet_tracker[pair_key]["spkts"],
            "dpkts": ip_packet_tracker[reverse_pair_key]["dpkts"],
            "ct_srv_src": service_src_connections[service_src_key],
            "ct_srv_dst": service_dst_connections[service_dst_key],
            "ct_dst_ltm": dest_connections[dst_ip],
            "smean": smean,
            "dmean": dmean,
            "attack_cat": ""
        }

        packet_batch.append(packet_info)

        if len(packet_batch) >= batch_size:
            packet_df = pd.DataFrame(packet_batch)
            packet_queue.put(packet_df)
            packet_batch = []

    except Exception as e:
        pass  # Safely handle any exceptions and skip the packet

# Function to process and predict batches with a confidence threshold
def process_batches(threshold=0.9):
    while True:
        if not packet_queue.empty():
            packet_df = packet_queue.get()

            # Preprocess the packet data
            processed_df = preprocess_packet_data(packet_df, label_encoders)

            # Get prediction probabilities from the model
            predictions_proba = model.predict(processed_df.drop(columns=["attack_cat"]))

            # Get the predicted classes (labels)
            predicted_classes = np.argmax(predictions_proba, axis=1)

            # Calculate confidence for each prediction
            confidence_scores = np.max(predictions_proba, axis=1)

            # Apply threshold: Mark as attack only if confidence exceeds the threshold
            predicted_categories = []
            for idx, score in enumerate(confidence_scores):
                if score >= threshold:
                    # Convert to attack category using label encoding
                    predicted_categories.append(label_encoders["attack_cat"].inverse_transform([predicted_classes[idx]])[0])
                else:
                    predicted_categories.append("Normal")  # Mark as Normal if confidence is below threshold

            # Update the dataframe with the predicted categories
            packet_df["attack_cat"] = predicted_categories

            # Print or store the result as needed
            print(packet_df)

# Start sniffing packets
def start_sniffing():
    sniff(filter="ip", prn=packet_callback, store=0)

# Start batch processing in a background thread
def start_processing():
    processing_thread = threading.Thread(target=process_batches, args=(0.9,)) 
    processing_thread.daemon = True
    processing_thread.start()

# Run the program
start_processing()
start_sniffing()



4/422ms/step━━━━━━━━━━━━2/40s3/4━━━━━━━━━━ ━━━━━━━━━━ ━━━━━━━━━━━━━━━0s━━━━━  11ms/step0s 9ms/st ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step
         dur  proto service state  sbytes  dbytes  sttl  dttl        sload  \
0   0.000000     17       -            75      75     1     1     0.000000   
1   0.000000     17       -            76      76     1     1     0.000000   
2   0.000000     17       -            72      72    10    10     0.000000   
3   0.000000     17       -           291     291   255   255     0.000000   
4   0.000000     17       -            72      72     1     1     0.000000   
..       ...    ...     ...   ...     ...     ...   ...   ...          ...   
95  3.088014     17       -           148     148     1     1  1533.671859   
96  3.089194     17       -           148     148     1     1  1916.357434   
97  3.089194     17       -           148     148     1     1  2299.628921   
98  4.493379     17       -           110     110   128 