In [1]:
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb import Settings
import os
from collections import Counter
import kagglehub
import re
from imblearn.over_sampling import RandomOverSampler
import json
import time
import gc

from data_loading.load_luflow import get_luflow
from data_loading.tools import reduce_mem_usage

# Persistent ChromaDB store, located two directories above the notebook's working directory.
vector_db_path = os.path.join(os.getcwd(), os.pardir, os.pardir, "data/vector_db")
client = chromadb.PersistentClient(
    path=vector_db_path,
    # allow_reset is required for client.reset() in the DB build cell.
    settings=Settings(anonymized_telemetry=False, allow_reset=True),
)

# Dataset switch: LUFlow when use_luflow is True, NF-UQ-NIDS-v2 otherwise.
use_luflow = False
use_uq = True  # NOTE(review): not read by any visible cell; only use_luflow drives the branches
collection_name = "packet_data"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if use_luflow:
    data = get_luflow(raw=True)
    # NOTE(review): LUFlow names its class column "label" (lower case), so the
    # NF-UQ "Label != 0" attack filter cannot be applied here (it raised KeyError
    # when it ran for both datasets). Add a LUFlow-specific filter if needed.
else:
    # kagglehub caches under $KAGGLEHUB_CACHE (default ~/.cache/kagglehub); reuse an
    # existing copy so the notebook works offline once the dataset is present.
    kaggle_cache_root = os.environ.get("KAGGLEHUB_CACHE", os.path.expanduser("~/.cache/kagglehub"))
    cache_path = os.path.join(kaggle_cache_root, "datasets")
    data_path = os.path.join(cache_path, "aryashah2k/nfuqnidsv2-network-intrusion-detection-dataset/versions/1")

    if not os.path.exists(data_path):
        # Download latest version
        data_path = kagglehub.dataset_download("aryashah2k/nfuqnidsv2-network-intrusion-detection-dataset")

    # Only the first 200k rows are read to keep memory and embedding time manageable.
    data = pd.read_csv(os.path.join(data_path, "NF-UQ-NIDS-v2.csv"), nrows=200_000)
    # Drop Label == 0 rows (benign traffic in NF-UQ-NIDS-v2) so only attack flows remain.
    # .copy() detaches the filtered rows from the original frame before reduce_mem_usage
    # rewrites column dtypes.
    data = data[data["Label"] != 0].copy()

data = reduce_mem_usage(data)
data.head()

Memory usage after optimization is: 20.35 MB
Decreased by 57.7%


  has_large_values = (abs_vals > 1e6).any()
  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,IPV4_SRC_ADDR,L4_SRC_PORT,IPV4_DST_ADDR,L4_DST_PORT,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,...,TCP_WIN_MAX_OUT,ICMP_TYPE,ICMP_IPV4_TYPE,DNS_QUERY_ID,DNS_QUERY_TYPE,DNS_TTL_ANSWER,FTP_COMMAND_RET_CODE,Label,Attack,Dataset
0,192.168.100.148,65389,192.168.100.7,80,6,7.0,420,3,0,0,...,0,35840,140,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
1,192.168.100.148,11154,192.168.100.5,80,6,7.0,280,2,40,1,...,0,0,0,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
7,192.168.1.31,54001,192.168.1.180,22,6,92.0,84,2,88,2,...,29200,0,0,0,0,0,0.0,1,scanning,NF-ToN-IoT-v2
8,192.168.100.147,33372,192.168.100.5,80,6,7.0,280,2,40,1,...,0,35840,140,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
9,192.168.1.31,37085,192.168.1.193,1863,6,0.0,44,1,40,1,...,0,0,0,0,0,0,0.0,1,scanning,NF-ToN-IoT-v2


In [86]:
# Class balance of the attack-only frame (most frequent class first).
data["Attack"].value_counts(sort=True, ascending=False)

Attack
DDoS              57505
DoS               47267
scanning           9947
Reconnaissance     6935
xss                6410
password           3071
injection          1809
Bot                 371
Brute Force         344
Infilteration       315
Exploits             77
Fuzzers              56
Backdoor             52
Generic              33
mitm                 19
ransomware           10
Analysis              7
Theft                 3
Shellcode             3
Worms                 1
Name: count, dtype: int64

In [3]:
X = data.drop(columns=['Attack'])
y = data['Attack']
# Randomly duplicate rows until every attack class matches the largest one.
# random_state pins the resample so the derived train/test rows (and the persisted
# vector DB built from them) stay consistent across kernel restarts.
# NOTE(review): resampling happens *before* the train/test split done later, so
# held-out rows can be exact copies of stored training rows (rare classes such as
# 'Worms' start from a single flow). Split first and oversample only the training
# part for an unbiased accuracy estimate.
oversampler = RandomOverSampler(sampling_strategy='all', random_state=42)
X_over, y_over = oversampler.fit_resample(X, y)
print(f"Over-sampled class distribution: {Counter(y_over)}")

Under-sampled class distribution: Counter({'DoS': 57505, 'scanning': 57505, 'DDoS': 57505, 'xss': 57505, 'Bot': 57505, 'Reconnaissance': 57505, 'password': 57505, 'Fuzzers': 57505, 'injection': 57505, 'Theft': 57505, 'Brute Force': 57505, 'Infilteration': 57505, 'Exploits': 57505, 'Generic': 57505, 'Analysis': 57505, 'Backdoor': 57505, 'mitm': 57505, 'Shellcode': 57505, 'ransomware': 57505, 'Worms': 57505})


In [4]:
# Row count after oversampling.
X_over.shape[0]

1150100

In [4]:
# only keep 10% of the oversampled data (fixed seed so the split is reproducible).
# NOTE: not idempotent — re-running this cell samples 10% of the already-reduced
# frame; re-run the oversampling cell first.
X_over = X_over.sample(frac=0.1, random_state=42)
# Align labels by index label explicitly (.loc) rather than via ambiguous [] indexing.
y_over = y_over.loc[X_over.index]
len(X_over)

115010

In [5]:
# Re-attach the labels in a new frame. The previous `data = X_over` aliased both
# names, so adding the column silently mutated X_over as well.
data = X_over.assign(Attack=y_over)
data.head()

  has_large_values = (abs_vals > 1e6).any()
  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,IPV4_SRC_ADDR,L4_SRC_PORT,IPV4_DST_ADDR,L4_DST_PORT,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,...,TCP_WIN_MAX_OUT,ICMP_TYPE,ICMP_IPV4_TYPE,DNS_QUERY_ID,DNS_QUERY_TYPE,DNS_TTL_ANSWER,FTP_COMMAND_RET_CODE,Label,Dataset,Attack
356817,18.221.219.4,33540,172.31.69.25,21,6,1.0,780,13,520,13,...,0,0,0,0,0,0,0.0,1,NF-CSE-CIC-IDS2018-v2,Brute Force
1028813,192.168.1.33,44501,192.168.1.193,445,6,41.0,7317,24,304,6,...,8192,22528,88,0,0,0,0.0,1,NF-ToN-IoT-v2,ransomware
310517,13.58.98.64,42240,172.31.69.25,22,6,92.0,3184,24,3869,23,...,26847,0,0,0,0,0,0.0,1,NF-CSE-CIC-IDS2018-v2,Brute Force
174057,175.45.176.0,0,149.171.126.17,0,216,0.0,200,2,0,0,...,0,0,0,0,0,0,0.0,1,NF-UNSW-NB15-v2,Analysis
335670,13.58.98.64,57588,172.31.69.25,22,6,92.0,3116,23,3869,23,...,26847,0,0,0,0,0,0.0,1,NF-CSE-CIC-IDS2018-v2,Brute Force


In [92]:
# Class counts after oversampling and the 10% random subsample.
print(f"Resampled class distribution (after 10% subsample): {Counter(y_over)}")

Under-sampled class distribution: Counter({'Worms': 5867, 'mitm': 5853, 'password': 5852, 'Backdoor': 5820, 'Exploits': 5799, 'Shellcode': 5796, 'Theft': 5778, 'Brute Force': 5756, 'Generic': 5756, 'Bot': 5755, 'xss': 5753, 'Reconnaissance': 5736, 'Analysis': 5736, 'DoS': 5714, 'Fuzzers': 5712, 'Infilteration': 5710, 'ransomware': 5702, 'scanning': 5685, 'injection': 5650, 'DDoS': 5580})


In [51]:
# Column index of the working frame (DataFrame.keys() returns the columns).
data.keys()

Index(['IPV4_SRC_ADDR', 'L4_SRC_PORT', 'IPV4_DST_ADDR', 'L4_DST_PORT',
       'PROTOCOL', 'L7_PROTO', 'IN_BYTES', 'IN_PKTS', 'OUT_BYTES', 'OUT_PKTS',
       'TCP_FLAGS', 'CLIENT_TCP_FLAGS', 'SERVER_TCP_FLAGS',
       'FLOW_DURATION_MILLISECONDS', 'DURATION_IN', 'DURATION_OUT', 'MIN_TTL',
       'MAX_TTL', 'LONGEST_FLOW_PKT', 'SHORTEST_FLOW_PKT', 'MIN_IP_PKT_LEN',
       'MAX_IP_PKT_LEN', 'SRC_TO_DST_SECOND_BYTES', 'DST_TO_SRC_SECOND_BYTES',
       'RETRANSMITTED_IN_BYTES', 'RETRANSMITTED_IN_PKTS',
       'RETRANSMITTED_OUT_BYTES', 'RETRANSMITTED_OUT_PKTS',
       'SRC_TO_DST_AVG_THROUGHPUT', 'DST_TO_SRC_AVG_THROUGHPUT',
       'NUM_PKTS_UP_TO_128_BYTES', 'NUM_PKTS_128_TO_256_BYTES',
       'NUM_PKTS_256_TO_512_BYTES', 'NUM_PKTS_512_TO_1024_BYTES',
       'NUM_PKTS_1024_TO_1514_BYTES', 'TCP_WIN_MAX_IN', 'TCP_WIN_MAX_OUT',
       'ICMP_TYPE', 'ICMP_IPV4_TYPE', 'DNS_QUERY_ID', 'DNS_QUERY_TYPE',
       'DNS_TTL_ANSWER', 'FTP_COMMAND_RET_CODE', 'Label', 'Dataset', 'Attack'],
      dtype='ob

In [7]:
def prepare_packet_documents_luflow(df):
    """Render every LUFlow flow in ``df`` as a text document with metadata.

    Missing columns are tolerated: the text shows ``unknown`` and the metadata
    falls back to ``0`` (or ``''`` for the two IP-address fields).

    Args:
        df: pandas DataFrame of LUFlow flows, one row per flow.

    Returns:
        list of dicts, one per row, with keys ``id`` (the row's index label),
        ``content`` (newline-joined description) and ``metadata``.
    """
    print("Preparing packet documents...")

    # Metadata keys copied from same-named columns, in output order.
    direct_keys = (
        "avg_ipt", "bytes_in", "bytes_out", "dest_ip", "dest_port", "entropy",
        "num_pkts_out", "num_pkts_in", "proto", "src_ip", "src_port",
        "time_end", "time_start", "total_entropy", "label", "duration",
    )
    text_keys = {"dest_ip", "src_ip"}  # default to '' rather than 0
    date_keys = (("year", "Year"), ("month", "Month"), ("day", "Day"))

    def field(row, column):
        # Text rendering of one column; absent columns read as 'unknown'.
        return row.get(column, 'unknown')

    prepared = []
    for flow_id, row in tqdm(df.iterrows(), total=len(df)):
        lines = [
            f"Flow ID: {flow_id}",
            # Basic flow metrics
            f"Average IPT: {field(row, 'avg_ipt')}",
            f"Bytes In: {field(row, 'bytes_in')}",
            f"Bytes Out: {field(row, 'bytes_out')}",
            # Connection details
            f"Connection: {field(row, 'src_ip')}:{field(row, 'src_port')} → {field(row, 'dest_ip')}:{field(row, 'dest_port')}",
            # Protocol information
            f"Protocol: {field(row, 'proto')}",
            # Packet counts
            f"Packets In: {field(row, 'num_pkts_in')}",
            f"Packets Out: {field(row, 'num_pkts_out')}",
            # Entropy and statistical features
            f"Entropy: {field(row, 'entropy')}",
            f"Total Entropy: {field(row, 'total_entropy')}",
            # Timing information
            f"Start Time: {field(row, 'time_start')}",
            f"End Time: {field(row, 'time_end')}",
            f"Duration: {field(row, 'duration')}",
            # Label and temporal context
            f"Label: {field(row, 'label')}",
            f"Date: {field(row, 'Year')}-{field(row, 'Month')}-{field(row, 'Day')}",
        ]

        meta = {key: row.get(key, '' if key in text_keys else 0) for key in direct_keys}
        for meta_key, column in date_keys:
            meta[meta_key] = row.get(column, 0)

        prepared.append({"id": flow_id, "content": "\n".join(lines), "metadata": meta})

    return prepared

In [8]:
def prepare_packet_documents_uq(df):
    """Render every NF-UQ-NIDS-v2 flow in ``df`` as a text document with metadata.

    Protocol-specific sections are emitted only when they apply: TCP details for
    PROTOCOL 6, ICMP details for PROTOCOL 1, and the DNS / FTP lines only when the
    column is present and not NaN. Missing columns render as ``unknown`` (or 0 for
    the volume and packet-size counters).

    Args:
        df: pandas DataFrame using NF-UQ-NIDS-v2 column names, one row per flow.

    Returns:
        list of dicts with keys ``id`` (row index label), ``content`` (the
        newline-joined description) and ``metadata`` (every non-NaN cell of the
        row, coerced to int/float/str/bool so ChromaDB can store it).
    """
    import numpy as np  # local import: only used to recognise numpy scalar types

    def to_metadata_value(value):
        # numpy scalars (e.g. np.int64 when every column is numeric) are not Python
        # ints; unwrap them so numeric metadata stays numeric instead of being
        # stringified (which would break $gt/$lt metadata filters).
        if isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, (int, float, str, bool)):
            return value
        # Convert other types to strings
        return str(value)

    documents = []

    for idx, row in df.iterrows():
        # Create a structured text representation of each flow
        doc_parts = []
        doc_parts.append(f"Flow ID: {idx}")

        # Connection details
        doc_parts.append(f"Connection: {row.get('IPV4_SRC_ADDR', 'unknown')}:{row.get('L4_SRC_PORT', 'unknown')} → {row.get('IPV4_DST_ADDR', 'unknown')}:{row.get('L4_DST_PORT', 'unknown')}")

        # Protocol information
        doc_parts.append(f"Protocol: {row.get('PROTOCOL', 'unknown')} (L7: {row.get('L7_PROTO', 'unknown')})")

        # Volume metrics
        doc_parts.append(f"Traffic Volume: {row.get('IN_BYTES', 0) + row.get('OUT_BYTES', 0)} bytes total ({row.get('IN_PKTS', 0) + row.get('OUT_PKTS', 0)} packets)")
        doc_parts.append(f"Inbound: {row.get('IN_BYTES', 0)} bytes in {row.get('IN_PKTS', 0)} packets")
        doc_parts.append(f"Outbound: {row.get('OUT_BYTES', 0)} bytes in {row.get('OUT_PKTS', 0)} packets")

        # TCP specific information
        if row.get('PROTOCOL') == 6:  # TCP
            doc_parts.append(f"TCP Flags: {row.get('TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"Client TCP Flags: {row.get('CLIENT_TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"Server TCP Flags: {row.get('SERVER_TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"TCP Window Max (In/Out): {row.get('TCP_WIN_MAX_IN', 'unknown')}/{row.get('TCP_WIN_MAX_OUT', 'unknown')}")

            # Retransmission stats
            doc_parts.append(f"Retransmissions: {row.get('RETRANSMITTED_IN_PKTS', 0) + row.get('RETRANSMITTED_OUT_PKTS', 0)} packets ({row.get('RETRANSMITTED_IN_BYTES', 0) + row.get('RETRANSMITTED_OUT_BYTES', 0)} bytes)")

        # ICMP specific information
        if row.get('PROTOCOL') == 1:  # ICMP
            doc_parts.append(f"ICMP Type: {row.get('ICMP_TYPE', 'unknown')}")
            doc_parts.append(f"ICMP IPv4 Type: {row.get('ICMP_IPV4_TYPE', 'unknown')}")

        # DNS specific information. pd.notna skips both absent columns (row.get ->
        # None) and NaN cells; the old `is not None` test let NaN render as "nan".
        if pd.notna(row.get('DNS_QUERY_ID')):
            doc_parts.append(f"DNS Query ID: {row.get('DNS_QUERY_ID', 'unknown')}")
            doc_parts.append(f"DNS Query Type: {row.get('DNS_QUERY_TYPE', 'unknown')}")
            doc_parts.append(f"DNS TTL Answer: {row.get('DNS_TTL_ANSWER', 'unknown')}")

        # FTP specific information (same NaN-aware presence check as DNS)
        if pd.notna(row.get('FTP_COMMAND_RET_CODE')):
            doc_parts.append(f"FTP Return Code: {row.get('FTP_COMMAND_RET_CODE', 'unknown')}")

        # Timing information
        doc_parts.append(f"Flow Duration: {row.get('FLOW_DURATION_MILLISECONDS', 'unknown')} ms")
        doc_parts.append(f"Duration In/Out: {row.get('DURATION_IN', 'unknown')}/{row.get('DURATION_OUT', 'unknown')}")

        # Throughput
        doc_parts.append(f"Throughput Client→Server: {row.get('SRC_TO_DST_AVG_THROUGHPUT', 'unknown')} bytes/s")
        doc_parts.append(f"Throughput Server→Client: {row.get('DST_TO_SRC_AVG_THROUGHPUT', 'unknown')} bytes/s")

        # Packet size distribution
        doc_parts.append("Packet Size Distribution:")
        doc_parts.append(f"  ≤128 bytes: {row.get('NUM_PKTS_UP_TO_128_BYTES', 0)} packets")
        doc_parts.append(f"  128-256 bytes: {row.get('NUM_PKTS_128_TO_256_BYTES', 0)} packets")
        doc_parts.append(f"  256-512 bytes: {row.get('NUM_PKTS_256_TO_512_BYTES', 0)} packets")
        doc_parts.append(f"  512-1024 bytes: {row.get('NUM_PKTS_512_TO_1024_BYTES', 0)} packets")
        doc_parts.append(f"  1024-1514 bytes: {row.get('NUM_PKTS_1024_TO_1514_BYTES', 0)} packets")

        # TTL and packet lengths
        doc_parts.append(f"TTL Range: {row.get('MIN_TTL', 'unknown')}-{row.get('MAX_TTL', 'unknown')}")
        doc_parts.append(f"Packet Length Range: {row.get('MIN_IP_PKT_LEN', 'unknown')}-{row.get('MAX_IP_PKT_LEN', 'unknown')}")
        doc_parts.append(f"Shortest/Longest Packet: {row.get('SHORTEST_FLOW_PKT', 'unknown')}/{row.get('LONGEST_FLOW_PKT', 'unknown')}")

        # Classification information
        if 'Label' in row:
            doc_parts.append(f"Label: {row.get('Label', 'unknown')}")
        if 'Attack' in row:
            doc_parts.append(f"Attack: {row.get('Attack', 'unknown')}")
        if 'Dataset' in row:
            doc_parts.append(f"Dataset: {row.get('Dataset', 'unknown')}")

        # Combine all parts into a single document
        document = "\n".join(doc_parts)

        # Metadata: every non-NaN cell, converted to a type ChromaDB accepts.
        metadata = {}
        for col, value in row.items():
            if pd.isna(value):
                continue
            metadata[str(col)] = to_metadata_value(value)

        documents.append({
            "id": idx,
            "content": document,
            "metadata": metadata
        })

    return documents

In [9]:
def create_embeddings(documents):
    """Attach an 'all-MiniLM-L6-v2' sentence embedding to every document.

    Args:
        documents: list of dicts that each carry a ``"content"`` string.

    Returns:
        The same list, mutated in place: every dict gains an ``"embedding"``
        entry holding the row of ``SentenceTransformer.encode`` output for it.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')

    print("Extracting text content from documents...")
    texts = list(map(lambda item: item["content"], tqdm(documents)))

    # One batched encode call for all texts.
    vectors = encoder.encode(texts)

    print(f"Generated {len(vectors)} embeddings")
    print("Adding embeddings to documents...")
    for item, vector in tqdm(zip(documents, vectors)):
        item["embedding"] = vector

    return documents

In [10]:
def store_in_vector_db(documents, client, collection_name="packet_data", batch_size=5000):
    """Write embedded documents to a ChromaDB collection in batches.

    Args:
        documents: list of dicts with "id", "content", "metadata" and
            "embedding" entries (as produced by create_embeddings).
        client: ChromaDB client used to create or open the collection.
        collection_name: target collection. Defaults to the name that used to
            be hard-coded, so existing callers are unaffected; pass the
            notebook's configured name to keep writes and later reads in sync.
        batch_size: rows per ``collection.add`` call, kept within ChromaDB's
            batch-size limits.

    Returns:
        The ChromaDB collection the documents were added to.

    Raises:
        ValueError: if batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Create or get a collection
    collection = client.get_or_create_collection(collection_name)

    # Prepare data for insertion (Chroma ids must be strings).
    ids = [str(doc["id"]) for doc in documents]
    # Embeddings from SentenceTransformer.encode are numpy rows; accept plain sequences too.
    embeddings = [
        doc["embedding"].tolist() if hasattr(doc["embedding"], "tolist") else list(doc["embedding"])
        for doc in documents
    ]
    metadatas = [doc["metadata"] for doc in documents]
    documents_text = [doc["content"] for doc in documents]

    # Add documents in batches; slicing clamps the last batch at the list end.
    for start in tqdm(range(0, len(ids), batch_size)):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            embeddings=embeddings[start:stop],
            metadatas=metadatas[start:stop],
            documents=documents_text[start:stop],
        )

    return collection

In [11]:
# Encoders loaded by retrieve_similar_packets, keyed by model name, so repeated
# queries (e.g. one per test document in a validation loop) do not reload the
# SentenceTransformer from disk on every call.
EMBEDDING_MODEL_CACHE = {}


def retrieve_similar_packets(query, collection, n=5, model=None):
    """Return the ``n`` stored flows closest to ``query`` in embedding space.

    Args:
        query: document text to embed (same format as the stored documents).
        collection: ChromaDB collection to search.
        n: number of neighbours to return.
        model: optional encoder exposing ``encode(text)``. When omitted, the
            'all-MiniLM-L6-v2' SentenceTransformer (the encoder used for the
            stored documents) is loaded once and reused across calls.

    Returns:
        The raw ``collection.query`` result for a single query embedding.
    """
    if model is None:
        model_name = 'all-MiniLM-L6-v2'
        if model_name not in EMBEDDING_MODEL_CACHE:
            EMBEDDING_MODEL_CACHE[model_name] = SentenceTransformer(model_name)
        model = EMBEDDING_MODEL_CACHE[model_name]

    # Convert query to embedding
    query_embedding = model.encode(query).tolist()

    # Query the collection
    return collection.query(
        query_embeddings=[query_embedding],
        n_results=n
    )

In [12]:
OVERWRITE_DB = False

if OVERWRITE_DB:
    client.reset()

test_size = 100
train_size = len(data) - test_size
assert 0 < test_size < len(data), "test_size must leave at least one training row"

# The held-out rows are the last `test_size` rows whether the DB is rebuilt or reused.
# NOTE(review): `data` comes from a random oversample, so these rows can duplicate
# rows stored in the collection; treat the resulting accuracy as optimistic.
test_data = data.iloc[train_size:]
if not use_luflow:
    test_data = test_data.drop(columns=["Label", "Dataset"])

# The PersistentClient creates its storage directory on start-up, so checking
# os.path.exists(vector_db_path) cannot tell whether the collection exists; ask
# the client. list_collections yields Collection objects or plain names depending
# on the chromadb version, hence the getattr.
existing_collections = {getattr(c, "name", c) for c in client.list_collections()}

if collection_name in existing_collections and not OVERWRITE_DB:
    collection = client.get_collection(collection_name)
else:
    # Convert to documents
    if use_luflow:
        # NOTE(review): this drops 'label' from the stored LUFlow rows, while the
        # NF-UQ branch keeps its 'Attack' class column; confirm this is intended.
        train_data = data.iloc[:train_size].drop(columns=["label"])
    else:
        train_data = data.iloc[:train_size].drop(columns=["Label", "Dataset"])

    train_documents = prepare_packet_documents_luflow(train_data) if use_luflow else prepare_packet_documents_uq(train_data)
    # Create embeddings
    embedded_docs = create_embeddings(train_documents)
    collection = store_in_vector_db(embedded_docs, client)

In [13]:
# Render held-out rows with the same document builder used for the stored set.
test_documents = (
    prepare_packet_documents_luflow if use_luflow else prepare_packet_documents_uq
)(test_data)

In [14]:
# Shallow-copy the first held-out document and hide its ground-truth label
# before using its text as a query.
example_doc = dict(test_documents[0])
example_doc["content"] = re.sub(r"Attack: .+", "Attack: <unknown>", example_doc["content"])
example_doc["content"]

'Flow ID: 1062102\nConnection: 192.168.1.32:52904 → 192.168.35.76:443\nProtocol: 6 (L7: 91.0)\nTraffic Volume: 44 bytes total (1 packets)\nInbound: 44 bytes in 1 packets\nOutbound: 0 bytes in 0 packets\nTCP Flags: 2\nClient TCP Flags: 2\nServer TCP Flags: 0\nTCP Window Max (In/Out): 1024/0\nRetransmissions: 0 packets (0 bytes)\nDNS Query ID: 0\nDNS Query Type: 0\nDNS TTL Answer: 0\nFTP Return Code: 0.0\nFlow Duration: 0 ms\nDuration In/Out: 0/0\nThroughput Client→Server: 352000 bytes/s\nThroughput Server→Client: 0 bytes/s\nPacket Size Distribution:\n  ≤128 bytes: 1 packets\n  128-256 bytes: 0 packets\n  256-512 bytes: 0 packets\n  512-1024 bytes: 0 packets\n  1024-1514 bytes: 0 packets\nTTL Range: 0-0\nPacket Length Range: 0-44\nShortest/Longest Packet: 44/44\nAttack: <unknown>'

In [132]:
# Majority vote over the 'Attack' labels of the 5 nearest stored flows.
results = retrieve_similar_packets(example_doc["content"], collection, n=5)
metadata = results["metadatas"]
attacks = [neighbour["Attack"] for neighbour in metadata[0]]
attack_counts = Counter(attacks)
attack_counts.most_common(1)

[('mitm', 5)]

In [31]:
# Ground-truth attack class of the example packet, to compare with the retrieval vote.
example_doc["metadata"]["Attack"]

'scanning'

In [99]:
def validate_vector_db(test_docs, vector_collection=None, n=5):
    """Estimate kNN-vote accuracy of the vector DB on held-out documents.

    For each document the ``Attack: ...`` line is masked, the ``n`` nearest
    stored flows are retrieved, and the most common ``'Attack'`` label among
    them (ties go to the label encountered first in the result list) is
    compared with the document's true label. Mismatches are printed.

    Args:
        test_docs: documents from prepare_packet_documents_* whose metadata
            holds the true ``'Attack'`` label.
        vector_collection: collection to query. When omitted, the
            notebook-level ``collection`` is used (the previous behaviour).
        n: number of neighbours that vote.

    Returns:
        Fraction of documents whose predicted label matches, in [0, 1].

    Raises:
        ValueError: if ``test_docs`` is empty (accuracy is undefined).
    """
    docs = list(test_docs)
    if not docs:
        raise ValueError("validate_vector_db needs at least one test document")
    target = collection if vector_collection is None else vector_collection

    correct = 0
    for doc in tqdm(docs):
        # Hide the true label so it cannot influence the query embedding.
        masked_content = re.sub(r'Attack: .+', 'Attack: <unknown>', doc['content'])
        results = retrieve_similar_packets(masked_content, target, n=n)
        neighbour_labels = [match['Attack'] for match in results['metadatas'][0]]
        top = Counter(neighbour_labels).most_common(1)
        predicted = top[0][0] if top else None  # no neighbours -> counted as a miss
        expected = doc['metadata']['Attack']

        if predicted == expected:
            correct += 1
        else:
            # tqdm.write keeps the progress bar intact while printing.
            tqdm.write(f"Expected: {expected}, Predicted: {predicted}")

    return correct / len(docs)

test_accuracy = validate_vector_db(test_documents)
test_accuracy

  7%|▋         | 7/100 [00:15<02:20,  1.51s/it]

Expected: DDoS, Predicted: injection


 11%|█         | 11/100 [00:17<01:19,  1.11it/s]

Expected: Reconnaissance, Predicted: scanning


 13%|█▎        | 13/100 [00:19<01:16,  1.13it/s]

Expected: ransomware, Predicted: Backdoor


 14%|█▍        | 14/100 [00:20<01:09,  1.24it/s]

Expected: Reconnaissance, Predicted: scanning


 35%|███▌      | 35/100 [00:35<00:45,  1.42it/s]

Expected: xss, Predicted: injection


 39%|███▉      | 39/100 [00:38<00:42,  1.43it/s]

Expected: DDoS, Predicted: injection


 40%|████      | 40/100 [00:39<00:45,  1.32it/s]

Expected: Reconnaissance, Predicted: DoS


 44%|████▍     | 44/100 [00:42<00:38,  1.46it/s]

Expected: ransomware, Predicted: Backdoor


 45%|████▌     | 45/100 [00:42<00:36,  1.49it/s]

Expected: Reconnaissance, Predicted: scanning


 60%|██████    | 60/100 [00:56<00:37,  1.06it/s]

Expected: password, Predicted: injection


 61%|██████    | 61/100 [00:57<00:35,  1.10it/s]

Expected: DDoS, Predicted: injection


 62%|██████▏   | 62/100 [00:57<00:33,  1.13it/s]

Expected: xss, Predicted: injection


 63%|██████▎   | 63/100 [00:58<00:31,  1.18it/s]

Expected: xss, Predicted: injection


 65%|██████▌   | 65/100 [00:59<00:26,  1.34it/s]

Expected: xss, Predicted: DoS


 70%|███████   | 70/100 [01:03<00:22,  1.31it/s]

Expected: DDoS, Predicted: Generic


 77%|███████▋  | 77/100 [01:08<00:17,  1.35it/s]

Expected: xss, Predicted: injection


 79%|███████▉  | 79/100 [01:09<00:14,  1.46it/s]

Expected: Backdoor, Predicted: Reconnaissance


 82%|████████▏ | 82/100 [01:13<00:16,  1.06it/s]

Expected: DDoS, Predicted: Generic


 87%|████████▋ | 87/100 [01:19<00:14,  1.14s/it]

Expected: ransomware, Predicted: Backdoor


 94%|█████████▍| 94/100 [01:24<00:04,  1.24it/s]

Expected: DDoS, Predicted: Generic


 98%|█████████▊| 98/100 [01:28<00:01,  1.19it/s]

Expected: xss, Predicted: injection


 99%|█████████▉| 99/100 [01:28<00:00,  1.19it/s]

Expected: Reconnaissance, Predicted: DoS


100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


0.78

In [72]:
def query_with_filters(collection, query_text, metadata_filters, n_results=5, model=None):
    """Query with both semantic search and metadata filters.

    Args:
        collection: ChromaDB collection to search.
        query_text: text to embed and search for.
        metadata_filters: mapping of metadata key -> value. A plain value means
            equality (wrapped as ``{"$eq": value}``); a dict is passed through
            as an operator expression, e.g. ``{"$gt": 4.0}``. May be empty.
        n_results: number of results to return.
        model: optional encoder exposing ``encode(text)``; defaults to a fresh
            'all-MiniLM-L6-v2' SentenceTransformer.

    Returns:
        The raw ``collection.query`` result.
    """
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')

    # Encode the query
    query_embedding = model.encode(query_text).tolist()

    clauses = [
        {key: value if isinstance(value, dict) else {"$eq": value}}
        for key, value in (metadata_filters or {}).items()
    ]

    # ChromaDB requires a single operator at the top level, and rejects "$and"
    # with fewer than two clauses, so only wrap when there are several filters.
    if not clauses:
        where = None
    elif len(clauses) == 1:
        where = clauses[0]
    else:
        where = {"$and": clauses}

    # Search the collection with properly formatted filters
    return collection.query(
        query_embeddings=[query_embedding],
        where=where,
        n_results=n_results
    )

In [15]:
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Local instruction-tuned LLM that writes the final verdict from the retrieved context.
model_name = "ehristoforu/coolqwen-3b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" lets transformers place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Text-generation pipeline; each call generates at most 1000 new tokens.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1000,
)

Downloading shards: 100%|██████████| 2/2 [02:25<00:00, 72.53s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]
Device set to use cuda:0


In [16]:
import json
import time

# Function to retrieve similar network packets for LLM context
def retrieve_context_for_llm(query_packet, collection, n=5):
    """Format the ``n`` nearest stored flows as a text block for an LLM prompt.

    Args:
        query_packet: document text of the packet being analysed.
        collection: ChromaDB collection to search.
        n: number of similar examples to include.

    Returns:
        A header line followed by one "--- Example i (Similarity: s) ---"
        section per retrieved document. ``s`` is ``1 / (1 + distance)`` with two
        decimals (a readability score, not a calibrated probability), or
        ``unknown`` when the result carries no distance for that document.
    """
    results = retrieve_similar_packets(query_packet, collection, n=n)

    # One inner list per query embedding; a single query was sent. Missing or
    # None fields (possible when not requested from Chroma) count as empty.
    documents = (results.get('documents') or [[]])[0] or []
    distances = (results.get('distances') or [[]])[0] or []

    parts = ["SIMILAR NETWORK TRAFFIC EXAMPLES:\n\n"]
    for i, doc in enumerate(documents):
        distance = distances[i] if i < len(distances) else None
        similarity = "unknown" if distance is None else f"{1 / (1 + distance):.2f}"
        parts.append(f"--- Example {i+1} (Similarity: {similarity}) ---\n")
        parts.append(f"{doc}\n\n")
    return "".join(parts)

# Function to query an LLM using Hugging Face Transformers
def query_llm(prompt, model, max_retries=3):
    """Run ``prompt`` through a text-generation callable, retrying on errors.

    Args:
        prompt: full prompt string.
        model: callable following the transformers text-generation pipeline
            convention: ``model(prompt, return_full_text=False)`` returns a list
            whose first item has a ``'generated_text'`` key.
        max_retries: total number of attempts. Values below 1 are treated as 1,
            so the model is always called at least once (previously 0 returned
            None without calling it).

    Returns:
        The generated text, or the string ``"{}"`` if every attempt raised.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            response = model(prompt, return_full_text=False)
            return response[0]['generated_text']
        except Exception as e:
            # Deliberate best-effort: any failure is retried with exponential
            # backoff (1s, 2s, 4s, ...) before giving up with an empty JSON object.
            if attempt < attempts - 1:
                sleep_time = 2 ** attempt
                print(f"Error querying LLM: {e}. Retrying in {sleep_time}s...")
                time.sleep(sleep_time)
            else:
                print(f"Failed to query LLM after {attempts} attempts: {e}")
                return "{}"

In [29]:
def _first_json_object(text):
    """Return the first JSON object embedded in ``text``, or None if none parses.

    Decodes with json.JSONDecoder.raw_decode starting at each '{', so nested
    objects and braces inside string values are handled (a non-greedy regex
    stops at the first '}').
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def nids_rag_analysis(packet_content, collection, model):
    """
    Perform RAG-based network intrusion detection

    Args:
        packet_content: The formatted packet content to analyze
        collection: Vector database collection containing network traffic examples
        model: text-generation callable (e.g. the transformers pipeline) that is
            passed to query_llm to produce the verdict

    Returns:
        Dict with analysis results: the first JSON object found in the LLM
        output, or a fallback dict (is_malicious=None, attack_type=None,
        confidence=0, explanation, raw_response) when no object can be parsed.
    """
    # Retrieve similar examples as context
    context = retrieve_context_for_llm(packet_content, collection)

    # Construct prompt for the LLM
    prompt = f"""You are a network security expert analyzing network traffic for potential intrusions.
    
I'll provide you with:
1. A network packet to analyze
2. Examples of similar network traffic from our database with known classifications

Based on the examples and your expertise, determine:
- If this packet is malicious or benign
- If malicious, what type of attack it represents (e.g., DoS, Port Scan, DDoS, etc.)
- Confidence level (0-1)
- Explanation of your reasoning

Here is the packet to analyze:

{packet_content}

Here are similar packets from our database for context:

{context}

Provide your analysis in JSON format with the following fields:
{{
  "is_malicious": true/false,
  "attack_type": "attack name or null if benign",
  "confidence": 0.xx,
  "explanation": "your detailed reasoning"
}}

Return only valid JSON, with no additional text.
"""

    # Send to LLM and get response
    response = query_llm(prompt, model)
    # remove newlines from response (raw newlines are not valid inside JSON strings)
    response = response.replace("\n", "")

    analysis = _first_json_object(response)
    if analysis is not None:
        return analysis

    # Fallback if LLM doesn't return valid JSON
    return {
        "is_malicious": None,
        "attack_type": None,
        "confidence": 0,
        "explanation": "Error parsing LLM response",
        "raw_response": response
    }

In [30]:
# End-to-end RAG verdict for the masked example packet.
nids_rag_analysis(packet_content=example_doc["content"], collection=collection, model=pipe)

{'is_malicious': True,
 'attack_type': 'scanning',
 'confidence': 0.99,
 'explanation': 'The packet has a high similarity score (0.99) with other benign scanning packets. It fits the pattern of a scanning attack, where the client initiates a TCP connection without establishing a session, sending small packets over a wide range of ports to identify open services on the target host.'}