In [73]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb import Settings
import json
import os
from collections import Counter
import kagglehub

from data_loading.load_luflow import get_luflow
from data_loading.tools import reduce_mem_usage

vector_db_path = os.path.join(os.getcwd(), os.pardir, os.pardir, "data/vector_db")
client = chromadb.PersistentClient(path=vector_db_path, settings=Settings(allow_reset=True, 
                                                                          anonymized_telemetry=False))

use_luflow = False
use_uq = True
collection_name = "packet_data"

In [92]:
if use_luflow:
    data = get_luflow(raw=True)
else:
    # Path to the cached dataset
    cache_path = os.path.expanduser("~/.cache/kagglehub/datasets")
    data_path = os.path.join(cache_path, "aryashah2k/nfuqnidsv2-network-intrusion-detection-dataset/versions/1")

    if not os.path.exists(data_path):
        # Download latest version
        data_path = kagglehub.dataset_download("aryashah2k/nfuqnidsv2-network-intrusion-detection-dataset")
    
    data = pd.read_csv(os.path.join(data_path, "NF-UQ-NIDS-v2.csv"), nrows=200_000)

data = reduce_mem_usage(data)
data.head()

Memory usage after optimization is: 33.76 MB
Decreased by 51.9%


  has_large_values = (abs_vals > 1e6).any()
  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,IPV4_SRC_ADDR,L4_SRC_PORT,IPV4_DST_ADDR,L4_DST_PORT,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,...,TCP_WIN_MAX_OUT,ICMP_TYPE,ICMP_IPV4_TYPE,DNS_QUERY_ID,DNS_QUERY_TYPE,DNS_TTL_ANSWER,FTP_COMMAND_RET_CODE,Label,Attack,Dataset
0,192.168.100.148,65389,192.168.100.7,80,6,7.0,420,3,0,0,...,0,35840,140,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
1,192.168.100.148,11154,192.168.100.5,80,6,7.0,280,2,40,1,...,0,0,0,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
2,192.168.1.31,42062,192.168.1.79,1041,6,0.0,44,1,40,1,...,0,0,0,0,0,0,0.0,0,Benign,NF-ToN-IoT-v2
3,192.168.1.34,46849,192.168.1.79,9110,6,0.0,44,1,40,1,...,0,0,0,0,0,0,0.0,0,Benign,NF-ToN-IoT-v2
4,192.168.1.30,50360,192.168.1.152,1084,6,0.0,44,1,40,1,...,0,0,0,0,0,0,0.0,0,Benign,NF-ToN-IoT-v2


In [75]:
data['Attack'].value_counts()

Attack
Benign            24793
DDoS              21491
DoS               17659
scanning           3694
Reconnaissance     2594
xss                2426
password           1129
injection           714
Bot                 148
Brute Force         148
Infilteration       122
Exploits             26
Fuzzers              18
Backdoor             17
Generic              13
mitm                  3
Shellcode             2
Theft                 1
Analysis              1
ransomware            1
Name: count, dtype: int64

In [62]:
data.columns

Index(['IPV4_SRC_ADDR', 'L4_SRC_PORT', 'IPV4_DST_ADDR', 'L4_DST_PORT',
       'PROTOCOL', 'L7_PROTO', 'IN_BYTES', 'IN_PKTS', 'OUT_BYTES', 'OUT_PKTS',
       'TCP_FLAGS', 'CLIENT_TCP_FLAGS', 'SERVER_TCP_FLAGS',
       'FLOW_DURATION_MILLISECONDS', 'DURATION_IN', 'DURATION_OUT', 'MIN_TTL',
       'MAX_TTL', 'LONGEST_FLOW_PKT', 'SHORTEST_FLOW_PKT', 'MIN_IP_PKT_LEN',
       'MAX_IP_PKT_LEN', 'SRC_TO_DST_SECOND_BYTES', 'DST_TO_SRC_SECOND_BYTES',
       'RETRANSMITTED_IN_BYTES', 'RETRANSMITTED_IN_PKTS',
       'RETRANSMITTED_OUT_BYTES', 'RETRANSMITTED_OUT_PKTS',
       'SRC_TO_DST_AVG_THROUGHPUT', 'DST_TO_SRC_AVG_THROUGHPUT',
       'NUM_PKTS_UP_TO_128_BYTES', 'NUM_PKTS_128_TO_256_BYTES',
       'NUM_PKTS_256_TO_512_BYTES', 'NUM_PKTS_512_TO_1024_BYTES',
       'NUM_PKTS_1024_TO_1514_BYTES', 'TCP_WIN_MAX_IN', 'TCP_WIN_MAX_OUT',
       'ICMP_TYPE', 'ICMP_IPV4_TYPE', 'DNS_QUERY_ID', 'DNS_QUERY_TYPE',
       'DNS_TTL_ANSWER', 'FTP_COMMAND_RET_CODE', 'Label', 'Attack', 'Dataset'],
      dtype='ob

In [76]:
def prepare_packet_documents_luflow(df):
    documents = []
    
    print("Preparing packet documents...")
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        # Create a structured text representation of each flow
        doc_parts = []
        doc_parts.append(f"Flow ID: {idx}")
        
        # Basic flow metrics
        doc_parts.append(f"Average IPT: {row.get('avg_ipt', 'unknown')}")
        doc_parts.append(f"Bytes In: {row.get('bytes_in', 'unknown')}")
        doc_parts.append(f"Bytes Out: {row.get('bytes_out', 'unknown')}")
        
        # Connection details
        doc_parts.append(f"Connection: {row.get('src_ip', 'unknown')}:{row.get('src_port', 'unknown')} → {row.get('dest_ip', 'unknown')}:{row.get('dest_port', 'unknown')}")
        
        # Protocol information
        doc_parts.append(f"Protocol: {row.get('proto', 'unknown')}")
        
        # Packet counts
        doc_parts.append(f"Packets In: {row.get('num_pkts_in', 'unknown')}")
        doc_parts.append(f"Packets Out: {row.get('num_pkts_out', 'unknown')}")
        
        # Entropy and statistical features
        doc_parts.append(f"Entropy: {row.get('entropy', 'unknown')}")
        doc_parts.append(f"Total Entropy: {row.get('total_entropy', 'unknown')}")
        
        # Timing information
        doc_parts.append(f"Start Time: {row.get('time_start', 'unknown')}")
        doc_parts.append(f"End Time: {row.get('time_end', 'unknown')}")
        doc_parts.append(f"Duration: {row.get('duration', 'unknown')}")
        
        # Label and temporal context
        doc_parts.append(f"Label: {row.get('label', 'unknown')}")
        doc_parts.append(f"Date: {row.get('Year', 'unknown')}-{row.get('Month', 'unknown')}-{row.get('Day', 'unknown')}")
        
        # Combine all parts into a single document
        document = "\n".join(doc_parts)
        
        # Store the document with its metadata
        documents.append({
            "id": idx, 
            "content": document, 
            "metadata": {
                "avg_ipt": row.get('avg_ipt', 0),
                "bytes_in": row.get('bytes_in', 0),
                "bytes_out": row.get('bytes_out', 0),
                "dest_ip": row.get('dest_ip', ''),
                "dest_port": row.get('dest_port', 0),
                "entropy": row.get('entropy', 0),
                "num_pkts_out": row.get('num_pkts_out', 0),
                "num_pkts_in": row.get('num_pkts_in', 0),
                "proto": row.get('proto', 0),
                "src_ip": row.get('src_ip', ''),
                "src_port": row.get('src_port', 0),
                "time_end": row.get('time_end', 0),
                "time_start": row.get('time_start', 0),
                "total_entropy": row.get('total_entropy', 0),
                "label": row.get('label', 0),
                "duration": row.get('duration', 0),
                "year": row.get('Year', 0),
                "month": row.get('Month', 0),
                "day": row.get('Day', 0)
            }
        })
    
    return documents

In [77]:
def prepare_packet_documents_uq(df):
    documents = []
    
    for idx, row in df.iterrows():
        # Create a structured text representation of each flow
        doc_parts = []
        doc_parts.append(f"Flow ID: {idx}")
        
        # Connection details
        doc_parts.append(f"Connection: {row.get('IPV4_SRC_ADDR', 'unknown')}:{row.get('L4_SRC_PORT', 'unknown')} → {row.get('IPV4_DST_ADDR', 'unknown')}:{row.get('L4_DST_PORT', 'unknown')}")
        
        # Protocol information
        doc_parts.append(f"Protocol: {row.get('PROTOCOL', 'unknown')} (L7: {row.get('L7_PROTO', 'unknown')})")
        
        # Volume metrics
        doc_parts.append(f"Traffic Volume: {row.get('IN_BYTES', 0) + row.get('OUT_BYTES', 0)} bytes total ({row.get('IN_PKTS', 0) + row.get('OUT_PKTS', 0)} packets)")
        doc_parts.append(f"Inbound: {row.get('IN_BYTES', 0)} bytes in {row.get('IN_PKTS', 0)} packets")
        doc_parts.append(f"Outbound: {row.get('OUT_BYTES', 0)} bytes in {row.get('OUT_PKTS', 0)} packets")
        
        # TCP specific information
        if row.get('PROTOCOL') == 6:  # TCP
            doc_parts.append(f"TCP Flags: {row.get('TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"Client TCP Flags: {row.get('CLIENT_TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"Server TCP Flags: {row.get('SERVER_TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"TCP Window Max (In/Out): {row.get('TCP_WIN_MAX_IN', 'unknown')}/{row.get('TCP_WIN_MAX_OUT', 'unknown')}")
            
            # Retransmission stats
            doc_parts.append(f"Retransmissions: {row.get('RETRANSMITTED_IN_PKTS', 0) + row.get('RETRANSMITTED_OUT_PKTS', 0)} packets ({row.get('RETRANSMITTED_IN_BYTES', 0) + row.get('RETRANSMITTED_OUT_BYTES', 0)} bytes)")
        
        # ICMP specific information
        if row.get('PROTOCOL') == 1:  # ICMP
            doc_parts.append(f"ICMP Type: {row.get('ICMP_TYPE', 'unknown')}")
            doc_parts.append(f"ICMP IPv4 Type: {row.get('ICMP_IPV4_TYPE', 'unknown')}")
        
        # DNS specific information
        if row.get('DNS_QUERY_ID') is not None:
            doc_parts.append(f"DNS Query ID: {row.get('DNS_QUERY_ID', 'unknown')}")
            doc_parts.append(f"DNS Query Type: {row.get('DNS_QUERY_TYPE', 'unknown')}")
            doc_parts.append(f"DNS TTL Answer: {row.get('DNS_TTL_ANSWER', 'unknown')}")
        
        # FTP specific information
        if row.get('FTP_COMMAND_RET_CODE') is not None:
            doc_parts.append(f"FTP Return Code: {row.get('FTP_COMMAND_RET_CODE', 'unknown')}")
        
        # Timing information
        doc_parts.append(f"Flow Duration: {row.get('FLOW_DURATION_MILLISECONDS', 'unknown')} ms")
        doc_parts.append(f"Duration In/Out: {row.get('DURATION_IN', 'unknown')}/{row.get('DURATION_OUT', 'unknown')}")
        
        # Throughput
        doc_parts.append(f"Throughput Client→Server: {row.get('SRC_TO_DST_AVG_THROUGHPUT', 'unknown')} bytes/s")
        doc_parts.append(f"Throughput Server→Client: {row.get('DST_TO_SRC_AVG_THROUGHPUT', 'unknown')} bytes/s")
        
        # Packet size distribution
        doc_parts.append("Packet Size Distribution:")
        doc_parts.append(f"  ≤128 bytes: {row.get('NUM_PKTS_UP_TO_128_BYTES', 0)} packets")
        doc_parts.append(f"  128-256 bytes: {row.get('NUM_PKTS_128_TO_256_BYTES', 0)} packets")
        doc_parts.append(f"  256-512 bytes: {row.get('NUM_PKTS_256_TO_512_BYTES', 0)} packets")
        doc_parts.append(f"  512-1024 bytes: {row.get('NUM_PKTS_512_TO_1024_BYTES', 0)} packets")
        doc_parts.append(f"  1024-1514 bytes: {row.get('NUM_PKTS_1024_TO_1514_BYTES', 0)} packets")
        
        # TTL and packet lengths
        doc_parts.append(f"TTL Range: {row.get('MIN_TTL', 'unknown')}-{row.get('MAX_TTL', 'unknown')}")
        doc_parts.append(f"Packet Length Range: {row.get('MIN_IP_PKT_LEN', 'unknown')}-{row.get('MAX_IP_PKT_LEN', 'unknown')}")
        doc_parts.append(f"Shortest/Longest Packet: {row.get('SHORTEST_FLOW_PKT', 'unknown')}/{row.get('LONGEST_FLOW_PKT', 'unknown')}")
        
        # Classification information
        if 'Label' in row:
            doc_parts.append(f"Label: {row.get('Label', 'unknown')}")
        if 'Attack' in row:
            doc_parts.append(f"Attack: {row.get('Attack', 'unknown')}")
        if 'Dataset' in row:
            doc_parts.append(f"Dataset: {row.get('Dataset', 'unknown')}")
        
        # Combine all parts into a single document
        document = "\n".join(doc_parts)
        
        # Store the document with its metadata
        # Convert metadata to appropriate types to avoid issues
        metadata = {}
        for col in row.index:
            value = row.get(col)
            # Handle NaN values and convert to basic types that ChromaDB can handle
            if pd.isna(value):
                continue
            elif isinstance(value, (int, float, str, bool)):
                metadata[col] = value
            else:
                # Convert other types to strings
                metadata[col] = str(value)
        
        documents.append({
            "id": idx, 
            "content": document, 
            "metadata": metadata
        })
    
    return documents

In [78]:
def create_embeddings(documents):
    # Load a pretrained model
    model = SentenceTransformer('all-MiniLM-L6-v2')
    
    print("Extracting text content from documents...")
    # Extract just the text content
    texts = [doc["content"] for doc in tqdm(documents)]
    
    # Generate embeddings
    embeddings = model.encode(texts)
    
    print(f"Generated {len(embeddings)} embeddings")
    print("Adding embeddings to documents...")
    # Add embeddings to documents
    for i, doc in tqdm(enumerate(documents)):
        doc["embedding"] = embeddings[i]
    
    return documents

In [79]:
def store_in_vector_db(documents, client):
    # Create or get a collection
    collection = client.get_or_create_collection("packet_data")
    
    # Prepare data for insertion
    ids = [str(doc["id"]) for doc in documents]
    embeddings = [doc["embedding"].tolist() for doc in documents]
    metadatas = [doc["metadata"] for doc in documents]
    documents_text = [doc["content"] for doc in documents]
    
    # Set batch size within ChromaDB's limits
    batch_size = 5000
    
    # Add documents in batches
    total_docs = len(ids)
    for i in tqdm(range(0, total_docs, batch_size)):
        end_idx = min(i + batch_size, total_docs)
        
        # Create batch
        batch_ids = ids[i:end_idx]
        batch_embeddings = embeddings[i:end_idx]
        batch_metadatas = metadatas[i:end_idx]
        batch_documents = documents_text[i:end_idx]
        
        # Add batch to collection
        collection.add(
            ids=batch_ids,
            embeddings=batch_embeddings,
            metadatas=batch_metadatas,
            documents=batch_documents
        )
    
    return collection

In [80]:
def retrieve_similar_packets(query, collection, n=5):
    # Load the same model used for document encoding
    model = SentenceTransformer('all-MiniLM-L6-v2')
    
    # Convert query to embedding
    query_embedding = model.encode(query).tolist()
    print(query_embedding)
    
    # Query the collection
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n
    )
    
    return results

In [81]:
OVERWRITE_DB = False

if OVERWRITE_DB:
    client.reset()

if os.path.exists(vector_db_path) and not OVERWRITE_DB:
    collection = client.get_collection(collection_name)
else:
    # Convert to documents
    if use_luflow:
        documents = prepare_packet_documents_luflow(data.drop(columns=["label"]))
    else:
        documents = prepare_packet_documents_uq(data.drop(columns=["Label", "Dataset"]))
    
    # Create embeddings
    embedded_docs = create_embeddings(documents)
    collection = store_in_vector_db(embedded_docs, client)

In [82]:
example_doc = documents[0]
example_doc['content'] = example_doc['content'].replace('Attack: DoS', 'Attack: <unknown>')
example_doc['content']

'Flow ID: 0\nConnection: 192.168.100.148:65389 → 192.168.100.7:80\nProtocol: 6 (L7: 7.0)\nTraffic Volume: 420 bytes total (3 packets)\nInbound: 420 bytes in 3 packets\nOutbound: 0 bytes in 0 packets\nTCP Flags: 2\nClient TCP Flags: 2\nServer TCP Flags: 0\nTCP Window Max (In/Out): 512/0\nRetransmissions: 1 packets (140 bytes)\nDNS Query ID: 0\nDNS Query Type: 0\nDNS TTL Answer: 0\nFTP Return Code: 0.0\nFlow Duration: 4293092 ms\nDuration In/Out: 1875/0\nThroughput Client→Server: 1120000 bytes/s\nThroughput Server→Client: 0 bytes/s\nPacket Size Distribution:\n  ≤128 bytes: 0 packets\n  128-256 bytes: 3 packets\n  256-512 bytes: 0 packets\n  512-1024 bytes: 0 packets\n  1024-1514 bytes: 0 packets\nTTL Range: 64-64\nPacket Length Range: 0-140\nShortest/Longest Packet: 140/140\nAttack: <unknown>'

In [83]:
results = retrieve_similar_packets(example_doc['content'], collection, n=5)
metadata = results['metadatas']
attacks = [match_dict['Attack'] for match_dict in metadata[0]]
attack_counts = Counter(attacks)
attack_counts.most_common(1)

[-0.05620570853352547, -0.005990129895508289, -0.016993943601846695, -0.026161985471844673, -0.05830743908882141, -0.0585225485265255, -0.007532202173024416, -0.04933980479836464, 0.045931700617074966, 0.00903173629194498, -0.06663363426923752, 0.0009082871838472784, -0.056930672377347946, -0.028181394562125206, -0.04223840311169624, -0.021119873970746994, 0.015627261251211166, -0.06866563856601715, 0.03683921694755554, -0.06406529992818832, 0.021532513201236725, 0.028900591656565666, -0.05274513363838196, -0.09050905704498291, -0.03244556114077568, 0.03792314976453781, -0.034447334706783295, 0.034374743700027466, 0.012925081886351109, -0.04453509673476219, -0.0003715877828653902, 0.0167424026876688, -0.11771264672279358, 0.006462490186095238, -0.025400720536708832, 0.02293000929057598, 0.0910630077123642, 0.062139272689819336, -0.04768531024456024, 0.06073344126343727, 0.0883118137717247, -0.05191110819578171, 0.06540928035974503, 0.11780250072479248, -0.002394604030996561, 0.00865774

[('DoS', 5)]

In [86]:
example_doc['metadata']['Attack']

'DoS'

In [94]:
test_docs = prepare_packet_documents_uq(data.iloc[:(len(data)) - 1].drop(columns=["Label", "Dataset"]))
test_docs[80_000]['content']

'Flow ID: 80000\nConnection: 114.43.207.102:50369 → 172.31.66.28:445\nProtocol: 6 (L7: 0.0)\nTraffic Volume: 840 bytes total (9 packets)\nInbound: 498 bytes in 5 packets\nOutbound: 342 bytes in 4 packets\nTCP Flags: 30\nClient TCP Flags: 26\nServer TCP Flags: 30\nTCP Window Max (In/Out): 16560/8192\nRetransmissions: 0 packets (0 bytes)\nDNS Query ID: 0\nDNS Query Type: 0\nDNS TTL Answer: 0\nFTP Return Code: 0.0\nFlow Duration: 0 ms\nDuration In/Out: 0/0\nThroughput Client→Server: 3984000 bytes/s\nThroughput Server→Client: 2736000 bytes/s\nPacket Size Distribution:\n  ≤128 bytes: 6 packets\n  128-256 bytes: 3 packets\n  256-512 bytes: 0 packets\n  512-1024 bytes: 0 packets\n  1024-1514 bytes: 0 packets\nTTL Range: 100-100\nPacket Length Range: 40-171\nShortest/Longest Packet: 40/171\nAttack: Benign'

In [95]:
test_content = test_docs[80_000]['content']
test_results = retrieve_similar_packets(test_content, collection, n=5)
test_metadata = test_results['metadatas']
test_attacks = [match_dict['Attack'] for match_dict in test_metadata[0]]
test_attack_counts = Counter(test_attacks)
test_attack_counts.most_common(1)

[-0.04087037965655327, -0.020381322130560875, -0.03110925480723381, -0.00647034402936697, -0.03313157334923744, -0.05401843786239624, -0.02347947657108307, -0.029236165806651115, 0.05314793065190315, 0.0038521408569067717, -0.07844346761703491, 0.02860700897872448, -0.06129780039191246, -0.018289247527718544, -0.03468776494264603, -0.037426989525556564, 0.02237490378320217, -0.08315383642911911, 0.03272686526179314, -0.03620027378201485, -0.009450693614780903, 0.023754317313432693, -0.03861020505428314, -0.09850215166807175, -0.05386893078684807, 0.0034558847546577454, -0.029220597818493843, 0.036198440939188004, -0.004991687834262848, -0.025976546108722687, 1.0081808795803227e-05, 0.036375973373651505, -0.09916147589683533, 0.015320827253162861, -0.009037282317876816, 0.04475220665335655, 0.10494452714920044, 0.03771550580859184, -0.04340944066643715, 0.05608333274722099, 0.09272707998752594, -0.10652230679988861, 0.05705270543694496, 0.12643401324748993, -0.03738673776388168, -0.0173

[('Benign', 5)]

In [96]:
test_docs[80_000]['metadata']['Attack']

'Benign'

In [72]:
def query_with_filters(collection, query_text, metadata_filters, n_results=5):
    """Query with both semantic search and metadata filters."""
    
    # Load embedding model
    model = SentenceTransformer('all-MiniLM-L6-v2')
    
    # Encode the query
    query_embedding = model.encode(query_text).tolist()
    
    # Convert flat filters to proper ChromaDB format
    # ChromaDB requires a single operator at the top level
    formatted_where = {"$and": []}
    
    for key, value in metadata_filters.items():
        if isinstance(value, dict):
            # This is already an operator (like {"$gt": 4.0})
            formatted_where["$and"].append({key: value})
        else:
            # This is a direct value match
            formatted_where["$and"].append({key: {"$eq": value}})
    
    # Search the collection with properly formatted filters
    results = collection.query(
        query_embeddings=[query_embedding],
        where=formatted_where,
        n_results=n_results
    )
    
    return results