In [1]:
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, InputExample, losses
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import BinaryClassificationEvaluator
import chromadb
from chromadb import Settings
from chromadb.utils import embedding_functions
import os
from collections import Counter
import kagglehub
import re
from imblearn.over_sampling import RandomOverSampler
from datasets import Dataset
import random
from mitreattack.stix20 import MitreAttackData
import requests

from data_loading.load_luflow import get_luflow
from data_loading.tools import reduce_mem_usage

vector_db_path = os.path.join(os.getcwd(), os.pardir, os.pardir, "data/vector_db")
client = chromadb.PersistentClient(path=vector_db_path, settings=Settings(allow_reset=True, 
                                                                          anonymized_telemetry=False))

use_luflow = False
use_uq = True
RESET_DB = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_url = "https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json"
# Download the data if it doesn't exist
if not os.path.exists('enterprise-attack.json'):
    response = requests.get(data_url)
    with open('enterprise-attack.json', 'wb') as f:
        f.write(response.content)
data = MitreAttackData('enterprise-attack.json')
techniques = data.get_techniques()
mitigations = data.get_all_mitigations_mitigating_all_techniques()
len(techniques), len(mitigations)

(799, 557)

In [6]:
# find scanning techniques
relevant_techniques = {}
attack_types = data['Attack'].unique()
for technique in techniques:
    for attack_type in attack_types:
        if attack_type.lower() in technique['name'].lower():
            relevant_techniques[technique['id']] = {
                'name': technique['name'],
                'description': technique['description'],
            }

print(len(relevant_techniques))

30


In [7]:
for id in relevant_techniques.keys():
    relevant_techniques[id]['mitigation'] = []
    if id in mitigations:
        mitigation = mitigations[id]
        relevant_techniques[id]['mitigation'] = []
        for mit in mitigation:
            mit_name = mit['object']['name']
            mit_description = mit['object']['description']
            if mit_name and mit_description:
                relevant_techniques[id]['mitigation'].append({
                    'name': mit_name,
                    'description': mit_description
                })

print(relevant_techniques)

{'attack-pattern--0042a9f5-f053-4769-b3ef-9ad018dfa298': {'name': 'Extra Window Memory Injection', 'description': "Adversaries may inject malicious code into process via Extra Window Memory (EWM) in order to evade process-based defenses as well as possibly elevate privileges. EWM injection is a method of executing arbitrary code in the address space of a separate live process. \n\nBefore creating a window, graphical Windows-based processes must prescribe to or register a windows class, which stipulate appearance and behavior (via windows procedures, which are functions that handle input/output of data).(Citation: Microsoft Window Classes) Registration of new windows classes can include a request for up to 40 bytes of EWM to be appended to the allocated memory of each instance of that class. This EWM is intended to store data specific to that window and has specific application programming interface (API) functions to set and get its value. (Citation: Microsoft GetWindowLong function) (

In [9]:
for technique in relevant_techniques:
    print(relevant_techniques[technique]['name'])
    for mitigation in relevant_techniques[technique]['mitigation']:
        print(mitigation)

Extra Window Memory Injection
{'name': 'Behavior Prevention on Endpoint', 'description': 'Use capabilities to prevent suspicious behavior patterns from occurring on endpoint systems. This could include suspicious process, file, API call, etc. behavior.'}
Password Guessing
{'name': 'Update Software', 'description': 'Perform regular software updates to mitigate exploitation risk.'}
{'name': 'Multi-factor Authentication', 'description': 'Use two or more pieces of evidence to authenticate to a system; such as username and password in addition to a token from a physical smart card or token generator.'}
{'name': 'Password Policies', 'description': 'Set and enforce secure password policies for accounts.'}
{'name': 'Account Use Policies', 'description': 'Configure features related to account use like login attempt lockouts, specific login times, etc.'}
Password Cracking
{'name': 'Password Policies', 'description': 'Set and enforce secure password policies for accounts.'}
{'name': 'Multi-factor

In [10]:
relevant_techniques

{'attack-pattern--0042a9f5-f053-4769-b3ef-9ad018dfa298': {'name': 'Extra Window Memory Injection',
  'description': "Adversaries may inject malicious code into process via Extra Window Memory (EWM) in order to evade process-based defenses as well as possibly elevate privileges. EWM injection is a method of executing arbitrary code in the address space of a separate live process. \n\nBefore creating a window, graphical Windows-based processes must prescribe to or register a windows class, which stipulate appearance and behavior (via windows procedures, which are functions that handle input/output of data).(Citation: Microsoft Window Classes) Registration of new windows classes can include a request for up to 40 bytes of EWM to be appended to the allocated memory of each instance of that class. This EWM is intended to store data specific to that window and has specific application programming interface (API) functions to set and get its value. (Citation: Microsoft GetWindowLong function)

In [8]:
# Create a list of documents with their metadata
documents = []
metadatas = []
ids = []

for attack_id, attack_info in relevant_techniques.items():
    # Process main description
    description = attack_info.get('description', '')
    name = attack_info.get('name', '')
    
    # Create a document with the full attack information
    document = f"Name: {name}\n\nDescription: {description}"
    
    # Add mitigation strategies if available
    mitigations = attack_info.get('mitigation', [])
    if mitigations:
        mitigation_text = "\n\nMitigations:\n"
        for mitigation in mitigations:
            mitigation_text += f"- {mitigation.get('name', '')}: {mitigation.get('description', '')}\n"
        document += mitigation_text
    
    # Add to our lists
    documents.append(document)
    metadatas.append({
        "attack_id": attack_id,
        "name": name,
        "document_type": "attack_pattern"
    })
    ids.append(attack_id)

# 2. Set up ChromaDB and the sentence transformer embedding function
# Initialize the embedding function using sentence-transformers
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"  # You can use other models like "all-mpnet-base-v2" for better quality
)

if RESET_DB:
    # Reset the database if it exists
    client.reset()

# Create a collection for cybersecurity attacks
collection = client.create_collection(
    name="attack_patterns",
    embedding_function=sentence_transformer_ef,
    metadata={"description": "MITRE ATT&CK patterns and techniques"}
)

# 3. Add documents to the collection
collection.add(
    documents=documents,
    metadatas=metadatas,
    ids=ids
)

print(f"Added {len(documents)} attack patterns to the vector database.")

Added 30 attack patterns to the vector database.


In [11]:
# 4. Example queries to test the database
def query_attacks(query_text, n_results=3):
    """Search for relevant attack patterns based on the query text."""
    results = collection.query(
        query_texts=[query_text],
        n_results=n_results
    )
    
    print(f"Query: '{query_text}'")
    print(f"Found {len(results['documents'][0])} relevant results:\n")
    
    for i, (doc, metadata, distance) in enumerate(zip(
        results['documents'][0], 
        results['metadatas'][0],
        results['distances'][0]
    )):
        print(f"Result {i+1}: {metadata['name']} (ID: {metadata['attack_id']})")
        print(f"Relevance score: {1 - distance:.4f}")
        print(f"{doc[:300]}...\n")  # Print first 300 chars of the document

# Test with some example queries
print("\n=== EXAMPLE QUERIES ===\n")
query_attacks("password brute force attacks", 3)
query_attacks("process injection techniques", 3)
query_attacks("credential stealing from password managers", 2)
query_attacks("vulnerabilities in Windows systems", 3)


=== EXAMPLE QUERIES ===

Query: 'password brute force attacks'
Found 3 relevant results:

Result 1: Brute Force (ID: attack-pattern--a93494bb-4b80-4ea1-8695-3236a49916fd)
Relevance score: 0.2947
Name: Brute Force

Description: Adversaries may use brute force techniques to gain access to accounts when passwords are unknown or when password hashes are obtained.(Citation: TrendMicro Pawn Storm Dec 2020) Without knowledge of the password for an account or set of accounts, an adversary may syste...

Result 2: Password Policy Discovery (ID: attack-pattern--b6075259-dba3-44e9-87c7-e954f37ec0d5)
Relevance score: 0.0954
Name: Password Policy Discovery

Description: Adversaries may attempt to access detailed information about the password policy used within an enterprise network or cloud environment. Password policies are a way to enforce complex passwords that are difficult to guess or crack through [Brute Force](h...

Result 3: Password Cracking (ID: attack-pattern--1d24cdee-9ea2-4189-b08e-af

In [104]:
# 5. Function to use in a RAG pipeline
def retrieve_context(query, n_results=5):
    """
    Retrieve relevant context from the vector DB for use in a RAG pipeline.
    Returns a formatted string with the most relevant content.
    """
    results = collection.query(
        query_texts=[query],
        n_results=n_results
    )
    
    context = ""
    for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
        context += f"--- {metadata['name']} ---\n"
        context += f"{doc}\n\n"
    
    return context

# Example of how to use in a RAG pipeline with an LLM
def rag_pipeline(query):
    # 1. Retrieve relevant context
    context = retrieve_context(query)
    
    # 2. Format prompt for LLM
    prompt = f"""
    Based on the following cybersecurity information, please answer the question or provide information about: {query}
    
    Context information:
    {context}
    """
    
    # 3. Send to LLM (pseudocode)
    # response = call_llm_api(prompt)
    
    # 4. Return response
    # return response
    
    # For demonstration purposes, just return the prompt
    return "Ready to send to LLM:\n" + prompt

# Example usage of the RAG pipeline
print("\n=== RAG PIPELINE EXAMPLE ===\n")
print(rag_pipeline("How can I protect against password attacks?"))


=== RAG PIPELINE EXAMPLE ===

Ready to send to LLM:

    Based on the following cybersecurity information, please answer the question or provide information about: How can I protect against password attacks?
    
    Context information:
    --- Password Policy Discovery ---
Name: Password Policy Discovery

Description: Adversaries may attempt to access detailed information about the password policy used within an enterprise network or cloud environment. Password policies are a way to enforce complex passwords that are difficult to guess or crack through [Brute Force](https://attack.mitre.org/techniques/T1110). This information may help the adversary to create a list of common passwords and launch dictionary and/or brute force attacks which adheres to the policy (e.g. if the minimum password length should be 8, then not trying passwords such as 'pass123'; not checking for more than 3-4 passwords per account if the lockout is set to 6 as to not lock out accounts).

Password policies ca

In [18]:
data['Attack'].value_counts()

Attack
Benign            65765
DDoS              57505
DoS               47267
scanning           9947
Reconnaissance     6935
xss                6410
password           3071
injection          1809
Bot                 371
Brute Force         344
Infilteration       315
Exploits             77
Fuzzers              56
Backdoor             52
Generic              33
mitm                 19
ransomware           10
Analysis              7
Theft                 3
Shellcode             3
Worms                 1
Name: count, dtype: int64

In [None]:
# Path to the cached dataset
cache_path = os.path.expanduser("~/.cache/kagglehub/datasets")
data_path = os.path.join(cache_path, "aryashah2k/nfuqnidsv2-network-intrusion-detection-dataset/versions/1")

if not os.path.exists(data_path):
    # Download latest version
    data_path = kagglehub.dataset_download("aryashah2k/nfuqnidsv2-network-intrusion-detection-dataset")

data = pd.read_csv(os.path.join(data_path, "NF-UQ-NIDS-v2.csv"), nrows=200_000)

data = data[data["Label"] != 0]
data = reduce_mem_usage(data)
data.head()

Memory usage after optimization is: 20.35 MB
Decreased by 57.7%


  has_large_values = (abs_vals > 1e6).any()
  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,IPV4_SRC_ADDR,L4_SRC_PORT,IPV4_DST_ADDR,L4_DST_PORT,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,...,TCP_WIN_MAX_OUT,ICMP_TYPE,ICMP_IPV4_TYPE,DNS_QUERY_ID,DNS_QUERY_TYPE,DNS_TTL_ANSWER,FTP_COMMAND_RET_CODE,Label,Attack,Dataset
0,192.168.100.148,65389,192.168.100.7,80,6,7.0,420,3,0,0,...,0,35840,140,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
1,192.168.100.148,11154,192.168.100.5,80,6,7.0,280,2,40,1,...,0,0,0,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
7,192.168.1.31,54001,192.168.1.180,22,6,92.0,84,2,88,2,...,29200,0,0,0,0,0,0.0,1,scanning,NF-ToN-IoT-v2
8,192.168.100.147,33372,192.168.100.5,80,6,7.0,280,2,40,1,...,0,35840,140,0,0,0,0.0,1,DoS,NF-BoT-IoT-v2
9,192.168.1.31,37085,192.168.1.193,1863,6,0.0,44,1,40,1,...,0,0,0,0,0,0,0.0,1,scanning,NF-ToN-IoT-v2


In [3]:
data['Attack'].value_counts()

Attack
DDoS              57505
DoS               47267
scanning           9947
Reconnaissance     6935
xss                6410
password           3071
injection          1809
Bot                 371
Brute Force         344
Infilteration       315
Exploits             77
Fuzzers              56
Backdoor             52
Generic              33
mitm                 19
ransomware           10
Analysis              7
Theft                 3
Shellcode             3
Worms                 1
Name: count, dtype: int64

In [4]:
X = data.drop(columns=['Attack'])
y = data['Attack']
oversampler = RandomOverSampler(sampling_strategy='all')
X_over, y_over = oversampler.fit_resample(X, y)
print(f"Under-sampled class distribution: {Counter(y_over)}")

Under-sampled class distribution: Counter({'DoS': 57505, 'scanning': 57505, 'DDoS': 57505, 'xss': 57505, 'Bot': 57505, 'Reconnaissance': 57505, 'password': 57505, 'Fuzzers': 57505, 'injection': 57505, 'Theft': 57505, 'Brute Force': 57505, 'Infilteration': 57505, 'Exploits': 57505, 'Generic': 57505, 'Analysis': 57505, 'Backdoor': 57505, 'mitm': 57505, 'Shellcode': 57505, 'ransomware': 57505, 'Worms': 57505})


In [5]:
len(X_over)

1150100

In [6]:
# only keep 10% of the oversampled data
X_over = X_over.sample(frac=0.1)
y_over = y_over[X_over.index]
len(X_over)

115010

In [7]:
data = X_over
data['Attack'] = y_over
data.head()

  has_large_values = (abs_vals > 1e6).any()
  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,IPV4_SRC_ADDR,L4_SRC_PORT,IPV4_DST_ADDR,L4_DST_PORT,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,...,TCP_WIN_MAX_OUT,ICMP_TYPE,ICMP_IPV4_TYPE,DNS_QUERY_ID,DNS_QUERY_TYPE,DNS_TTL_ANSWER,FTP_COMMAND_RET_CODE,Label,Dataset,Attack
76221,192.168.1.31,40544,192.168.1.184,443,6,91.0,164,3,112,2,...,28960,0,0,0,0,0,0.0,1,NF-ToN-IoT-v2,DDoS
1120857,192.168.1.35,59404,192.168.1.1,53,17,0.0,65,1,65,1,...,0,0,0,64273,28,0,0.0,1,NF-ToN-IoT-v2,xss
295604,172.31.69.17,58972,18.219.211.138,8080,6,131.75,538,5,341,5,...,26883,0,0,0,0,0,0.0,1,NF-CSE-CIC-IDS2018-v2,Bot
125373,192.168.100.149,2618,192.168.100.7,80,6,7.0,280,2,0,0,...,0,0,0,0,0,0,0.0,1,NF-BoT-IoT-v2,DoS
388045,175.45.176.1,16974,149.171.126.12,80,6,7.0,760,10,1342,8,...,16383,52224,204,0,0,0,0.0,1,NF-UNSW-NB15-v2,Exploits


In [8]:
print(f"Under-sampled class distribution: {Counter(y_over)}")

Under-sampled class distribution: Counter({'Backdoor': 5894, 'ransomware': 5844, 'xss': 5839, 'Brute Force': 5833, 'Analysis': 5802, 'Fuzzers': 5788, 'mitm': 5762, 'Shellcode': 5758, 'DDoS': 5748, 'Bot': 5739, 'Infilteration': 5726, 'password': 5714, 'injection': 5714, 'Exploits': 5704, 'Generic': 5703, 'Worms': 5703, 'Reconnaissance': 5700, 'Theft': 5699, 'scanning': 5686, 'DoS': 5654})


In [9]:
data.columns

Index(['IPV4_SRC_ADDR', 'L4_SRC_PORT', 'IPV4_DST_ADDR', 'L4_DST_PORT',
       'PROTOCOL', 'L7_PROTO', 'IN_BYTES', 'IN_PKTS', 'OUT_BYTES', 'OUT_PKTS',
       'TCP_FLAGS', 'CLIENT_TCP_FLAGS', 'SERVER_TCP_FLAGS',
       'FLOW_DURATION_MILLISECONDS', 'DURATION_IN', 'DURATION_OUT', 'MIN_TTL',
       'MAX_TTL', 'LONGEST_FLOW_PKT', 'SHORTEST_FLOW_PKT', 'MIN_IP_PKT_LEN',
       'MAX_IP_PKT_LEN', 'SRC_TO_DST_SECOND_BYTES', 'DST_TO_SRC_SECOND_BYTES',
       'RETRANSMITTED_IN_BYTES', 'RETRANSMITTED_IN_PKTS',
       'RETRANSMITTED_OUT_BYTES', 'RETRANSMITTED_OUT_PKTS',
       'SRC_TO_DST_AVG_THROUGHPUT', 'DST_TO_SRC_AVG_THROUGHPUT',
       'NUM_PKTS_UP_TO_128_BYTES', 'NUM_PKTS_128_TO_256_BYTES',
       'NUM_PKTS_256_TO_512_BYTES', 'NUM_PKTS_512_TO_1024_BYTES',
       'NUM_PKTS_1024_TO_1514_BYTES', 'TCP_WIN_MAX_IN', 'TCP_WIN_MAX_OUT',
       'ICMP_TYPE', 'ICMP_IPV4_TYPE', 'DNS_QUERY_ID', 'DNS_QUERY_TYPE',
       'DNS_TTL_ANSWER', 'FTP_COMMAND_RET_CODE', 'Label', 'Dataset', 'Attack'],
      dtype='ob

In [None]:
def prepare_packet_documents(df):
    documents = []
    
    for idx, row in df.iterrows():
        # Create a structured text representation of each flow
        doc_parts = []
        doc_parts.append(f"Flow ID: {idx}")
        
        # Connection details
        doc_parts.append(f"Connection: {row.get('IPV4_SRC_ADDR', 'unknown')}:{row.get('L4_SRC_PORT', 'unknown')} → {row.get('IPV4_DST_ADDR', 'unknown')}:{row.get('L4_DST_PORT', 'unknown')}")
        
        # Protocol information
        doc_parts.append(f"Protocol: {row.get('PROTOCOL', 'unknown')} (L7: {row.get('L7_PROTO', 'unknown')})")
        
        # Volume metrics
        doc_parts.append(f"Traffic Volume: {row.get('IN_BYTES', 0) + row.get('OUT_BYTES', 0)} bytes total ({row.get('IN_PKTS', 0) + row.get('OUT_PKTS', 0)} packets)")
        doc_parts.append(f"Inbound: {row.get('IN_BYTES', 0)} bytes in {row.get('IN_PKTS', 0)} packets")
        doc_parts.append(f"Outbound: {row.get('OUT_BYTES', 0)} bytes in {row.get('OUT_PKTS', 0)} packets")
        
        # TCP specific information
        if row.get('PROTOCOL') == 6:  # TCP
            doc_parts.append(f"TCP Flags: {row.get('TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"Client TCP Flags: {row.get('CLIENT_TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"Server TCP Flags: {row.get('SERVER_TCP_FLAGS', 'unknown')}")
            doc_parts.append(f"TCP Window Max (In/Out): {row.get('TCP_WIN_MAX_IN', 'unknown')}/{row.get('TCP_WIN_MAX_OUT', 'unknown')}")
            
            # Retransmission stats
            doc_parts.append(f"Retransmissions: {row.get('RETRANSMITTED_IN_PKTS', 0) + row.get('RETRANSMITTED_OUT_PKTS', 0)} packets ({row.get('RETRANSMITTED_IN_BYTES', 0) + row.get('RETRANSMITTED_OUT_BYTES', 0)} bytes)")
        
        # ICMP specific information
        if row.get('PROTOCOL') == 1:  # ICMP
            doc_parts.append(f"ICMP Type: {row.get('ICMP_TYPE', 'unknown')}")
            doc_parts.append(f"ICMP IPv4 Type: {row.get('ICMP_IPV4_TYPE', 'unknown')}")
        
        # DNS specific information
        if row.get('DNS_QUERY_ID') is not None:
            doc_parts.append(f"DNS Query ID: {row.get('DNS_QUERY_ID', 'unknown')}")
            doc_parts.append(f"DNS Query Type: {row.get('DNS_QUERY_TYPE', 'unknown')}")
            doc_parts.append(f"DNS TTL Answer: {row.get('DNS_TTL_ANSWER', 'unknown')}")
        
        # FTP specific information
        if row.get('FTP_COMMAND_RET_CODE') is not None:
            doc_parts.append(f"FTP Return Code: {row.get('FTP_COMMAND_RET_CODE', 'unknown')}")
        
        # Timing information
        doc_parts.append(f"Flow Duration: {row.get('FLOW_DURATION_MILLISECONDS', 'unknown')} ms")
        doc_parts.append(f"Duration In/Out: {row.get('DURATION_IN', 'unknown')}/{row.get('DURATION_OUT', 'unknown')}")
        
        # Throughput
        doc_parts.append(f"Throughput Client→Server: {row.get('SRC_TO_DST_AVG_THROUGHPUT', 'unknown')} bytes/s")
        doc_parts.append(f"Throughput Server→Client: {row.get('DST_TO_SRC_AVG_THROUGHPUT', 'unknown')} bytes/s")
        
        # Packet size distribution
        doc_parts.append("Packet Size Distribution:")
        doc_parts.append(f"  ≤128 bytes: {row.get('NUM_PKTS_UP_TO_128_BYTES', 0)} packets")
        doc_parts.append(f"  128-256 bytes: {row.get('NUM_PKTS_128_TO_256_BYTES', 0)} packets")
        doc_parts.append(f"  256-512 bytes: {row.get('NUM_PKTS_256_TO_512_BYTES', 0)} packets")
        doc_parts.append(f"  512-1024 bytes: {row.get('NUM_PKTS_512_TO_1024_BYTES', 0)} packets")
        doc_parts.append(f"  1024-1514 bytes: {row.get('NUM_PKTS_1024_TO_1514_BYTES', 0)} packets")
        
        # TTL and packet lengths
        doc_parts.append(f"TTL Range: {row.get('MIN_TTL', 'unknown')}-{row.get('MAX_TTL', 'unknown')}")
        doc_parts.append(f"Packet Length Range: {row.get('MIN_IP_PKT_LEN', 'unknown')}-{row.get('MAX_IP_PKT_LEN', 'unknown')}")
        doc_parts.append(f"Shortest/Longest Packet: {row.get('SHORTEST_FLOW_PKT', 'unknown')}/{row.get('LONGEST_FLOW_PKT', 'unknown')}")
        
        # Classification information
        if 'Label' in row:
            doc_parts.append(f"Label: {row.get('Label', 'unknown')}")
        if 'Attack' in row:
            doc_parts.append(f"Attack: {row.get('Attack', 'unknown')}")
        if 'Dataset' in row:
            doc_parts.append(f"Dataset: {row.get('Dataset', 'unknown')}")
        
        # Combine all parts into a single document
        document = "\n".join(doc_parts)
        
        # Store the document with its metadata
        # Convert metadata to appropriate types to avoid issues
        metadata = {}
        for col in row.index:
            value = row.get(col)
            # Handle NaN values and convert to basic types that ChromaDB can handle
            if pd.isna(value):
                continue
            elif isinstance(value, (int, float, str, bool)):
                metadata[col] = value
            else:
                # Convert other types to strings
                metadata[col] = str(value)
        
        documents.append({
            "id": idx, 
            "content": document, 
            "metadata": metadata
        })
    
    return documents

In [12]:
def create_embeddings(documents, model):
    print("Extracting text content from documents...")
    # Extract just the text content
    texts = [doc["content"] for doc in tqdm(documents)]
    
    # Generate embeddings
    embeddings = model.encode(texts)
    
    print(f"Generated {len(embeddings)} embeddings")
    print("Adding embeddings to documents...")
    # Add embeddings to documents
    for i, doc in tqdm(enumerate(documents)):
        doc["embedding"] = embeddings[i]
    
    return documents

In [13]:
def store_in_vector_db(documents, client):
    # Create or get a collection
    collection = client.get_or_create_collection("packet_data")
    
    # Prepare data for insertion
    ids = [str(doc["id"]) for doc in documents]
    embeddings = [doc["embedding"].tolist() for doc in documents]
    metadatas = [doc["metadata"] for doc in documents]
    documents_text = [doc["content"] for doc in documents]
    
    # Set batch size within ChromaDB's limits
    batch_size = 5000
    
    # Add documents in batches
    total_docs = len(ids)
    for i in tqdm(range(0, total_docs, batch_size)):
        end_idx = min(i + batch_size, total_docs)
        
        # Create batch
        batch_ids = ids[i:end_idx]
        batch_embeddings = embeddings[i:end_idx]
        batch_metadatas = metadatas[i:end_idx]
        batch_documents = documents_text[i:end_idx]
        
        # Add batch to collection
        collection.add(
            ids=batch_ids,
            embeddings=batch_embeddings,
            metadatas=batch_metadatas,
            documents=batch_documents
        )
    
    return collection

In [14]:
def retrieve_similar_packets(query, collection, model, n=5):
    # Convert query to embedding
    query_embedding = model.encode(query).tolist()
    
    # Query the collection
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n
    )
    
    return results

In [15]:
def create_enriched_document_features(df):
    # Create derived features that might help the model
    for idx, row in df.iterrows():
        # Calculate bytes per packet ratio
        if row.get('IN_PKTS', 0) > 0:
            df.at[idx, 'IN_BYTES_PER_PKT'] = row.get('IN_BYTES', 0) / row.get('IN_PKTS', 0)
        
        if row.get('OUT_PKTS', 0) > 0:
            df.at[idx, 'OUT_BYTES_PER_PKT'] = row.get('OUT_BYTES', 0) / row.get('OUT_PKTS', 0)
        
        # Calculate retransmission ratio
        if row.get('IN_PKTS', 0) > 0:
            df.at[idx, 'RETRANS_IN_RATIO'] = row.get('RETRANSMITTED_IN_PKTS', 0) / row.get('IN_PKTS', 0)
        
        # Calculate traffic asymmetry
        total_bytes = row.get('IN_BYTES', 0) + row.get('OUT_BYTES', 0)
        if total_bytes > 0:
            df.at[idx, 'TRAFFIC_ASYMMETRY'] = abs(row.get('IN_BYTES', 0) - row.get('OUT_BYTES', 0)) / total_bytes
    
    return df

In [16]:
def hybrid_search(query, collection, n=5, alpha=0.7):
    # Load model
    model = SentenceTransformer('all-mpnet-base-v2')
    query_embedding = model.encode(query).tolist()
    
    # Perform semantic search
    semantic_results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n*2  # Get more results than needed for hybrid reranking
    )
    
    # Perform keyword search
    keyword_results = collection.query(
        query_texts=[query],
        n_results=n*2
    )
    
    # Combine results with weighted scoring
    combined_results = {}
    
    # Process semantic results
    for i, doc_id in enumerate(semantic_results['ids'][0]):
        score = 1.0 - (i / len(semantic_results['ids'][0]))  # Normalize score
        combined_results[doc_id] = alpha * score
    
    # Process keyword results
    for i, doc_id in enumerate(keyword_results['ids'][0]):
        score = 1.0 - (i / len(keyword_results['ids'][0]))  # Normalize score
        combined_results[doc_id] = combined_results.get(doc_id, 0) + (1-alpha) * score
    
    # Sort by combined score and get top n
    top_docs = sorted(combined_results.items(), key=lambda x: x[1], reverse=True)[:n]
    
    # Retrieve full documents for top results
    result_ids = [doc_id for doc_id, _ in top_docs]
    final_results = collection.get(ids=result_ids)
    
    return final_results

In [17]:
def evaluate_retrieval_quality(collection, queries, ground_truth):
    """
    Evaluate retrieval quality using precision, recall, and F1 score
    
    Args:
        collection: ChromaDB collection
        test_data: Test dataset
        queries: List of test queries
        ground_truth: Dict mapping queries to relevant document IDs
    
    Returns:
        Dict with precision, recall, and F1 metrics
    """
    results = {
        'precision': [],
        'recall': [],
        'f1': [],
    }
    
    for query, relevant_docs in zip(queries, ground_truth):
        retrieved_docs = retrieve_similar_packets(query, collection, n=10)
        retrieved_ids = retrieved_docs['ids'][0]
        
        # Calculate precision and recall
        relevant_retrieved = set(retrieved_ids).intersection(set(relevant_docs))
        precision = len(relevant_retrieved) / len(retrieved_ids) if retrieved_ids else 0
        recall = len(relevant_retrieved) / len(relevant_docs) if relevant_docs else 0
        
        # Calculate F1
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        results['precision'].append(precision)
        results['recall'].append(recall)
        results['f1'].append(f1)
    
    # Average the metrics
    for metric in results:
        results[metric] = sum(results[metric]) / len(results[metric])
    
    return results

In [21]:
train_documents[0]

{'id': 243266,
 'content': 'Flow ID: 243266\nConnection: 192.168.1.193:54518 → 192.168.1.33:80\nProtocol: 6 (L7: 7.0)\nTraffic Volume: 2640 bytes total (60 packets)\nInbound: 1440 bytes in 30 packets\nOutbound: 1200 bytes in 30 packets\nTCP Flags: 22\nClient TCP Flags: 2\nServer TCP Flags: 20\nTCP Window Max (In/Out): 8192/0\nRetransmissions: 0 packets (0 bytes)\nDNS Query ID: 0\nDNS Query Type: 0\nDNS TTL Answer: 0\nFTP Return Code: 0.0\nFlow Duration: 0 ms\nDuration In/Out: 0/0\nThroughput Client→Server: 11520000 bytes/s\nThroughput Server→Client: 9600000 bytes/s\nPacket Size Distribution:\n  ≤128 bytes: 60 packets\n  128-256 bytes: 0 packets\n  256-512 bytes: 0 packets\n  512-1024 bytes: 0 packets\n  1024-1514 bytes: 0 packets\nTTL Range: 128-128\nPacket Length Range: 40-48\nShortest/Longest Packet: 40/48\nAttack: Backdoor',
 'metadata': {'IPV4_SRC_ADDR': '192.168.1.193',
  'L4_SRC_PORT': 54518,
  'IPV4_DST_ADDR': '192.168.1.33',
  'L4_DST_PORT': 80,
  'PROTOCOL': 6,
  'L7_PROTO': 7

In [None]:
def fine_tune_embedding_model(train_documents, val_documents, label_field='Attack', model_name='all-MiniLM-L6-v2', epochs=1):
    """
    Fine-tune a sentence transformer model for network traffic classification
    Args:
        documents: List of document dictionaries with 'content' and 'metadata'
        label_field: Metadata field containing the label/attack type
        model_name: Base model to fine-tune
    Returns:
        Fine-tuned SentenceTransformer model
    """
    random.seed(42)
    def create_dataset_pairs(documents, label_field):
        """Create positive and negative pairs for training"""
        texts = [doc["content"] for doc in documents]
        labels = [doc["metadata"].get(label_field, "unknown") for doc in documents]
        
        # Group indices by label
        label_to_indices = {}
        for i, label in enumerate(labels):
            if label not in label_to_indices:
                label_to_indices[label] = []
            label_to_indices[label].append(i)
        
        sentence1 = []
        sentence2 = []
        scores = []
        
        # Create positive pairs (same label)
        for label, indices in label_to_indices.items():
            if len(indices) < 2:
                continue  # Skip labels with only one example
                
            # Sample pairs instead of creating all combinations
            for _ in range(min(1000, len(indices) * 2)):
                i, j = random.sample(indices, 2)
                sentence1.append(texts[i])
                sentence2.append(texts[j])
                scores.append(1.0)
        
        # Create negative pairs (different labels) - balanced with positive pairs
        num_negatives = len(sentence1)
        label_list = list(label_to_indices.keys())
        
        for _ in range(num_negatives):
            # Select two different labels
            if len(label_list) < 2:
                break
            label1, label2 = random.sample(label_list, 2)
            
            # Select one document from each label
            i = random.choice(label_to_indices[label1])
            j = random.choice(label_to_indices[label2])
            
            sentence1.append(texts[i])
            sentence2.append(texts[j])
            scores.append(0.0)
        
        return Dataset.from_dict({
            "sentence1": sentence1,
            "sentence2": sentence2,
            "score": scores
        })
    
    def create_input_examples_from_dataset(dataset):
        """Convert dataset entries to InputExample objects"""
        examples = []
        for i in range(len(dataset)):
            examples.append(
                InputExample(
                    texts=[dataset[i]["sentence1"], dataset[i]["sentence2"]],
                    label=float(dataset[i]["score"])
                )
            )
        return examples
    
    # Create datasets for training and validation
    train_dataset = create_dataset_pairs(train_documents, label_field)
    val_dataset = create_dataset_pairs(val_documents, label_field)
    
    # Convert validation dataset to InputExample format for the evaluator
    val_examples = create_input_examples_from_dataset(val_dataset)
    evaluator = BinaryClassificationEvaluator.from_input_examples(val_examples, name='val')
    
    # Initialize the model
    model = SentenceTransformer(model_name)
    
    # Create loss function
    loss = losses.CosineSimilarityLoss(model)
    
    print(f"Fine-tuning model with {len(train_dataset)} training pairs")
    training_args = SentenceTransformerTrainingArguments(
        output_dir='fine_tuned_model',
        num_train_epochs=epochs,
        logging_dir='logs',
        eval_strategy='steps',
        do_eval=True,
        do_train=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        load_best_model_at_end=True,
        metric_for_best_model='eval_val_cosine_f1',
    )
    
    # Create trainer
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    
    # Train the model
    trainer.train()
    
    return model

In [19]:
OVERWRITE_DB = True

if OVERWRITE_DB:
    client.reset()

test_size = 100
train_size = len(data) - test_size

if os.path.exists(vector_db_path) and not OVERWRITE_DB:
    collection = client.get_collection(collection_name)
    if use_luflow:
        test_data = data.iloc[train_size:]
    else:
        test_data = data.iloc[train_size:].drop(columns=["Label", "Dataset"])
    
else:
    # Convert to documents
    if use_luflow:
        train_data = data.iloc[:train_size].drop(columns=["label"])
        test_data = data.iloc[train_size:]
    else:
        train_data = data.iloc[:train_size].drop(columns=["Label", "Dataset"])
        test_data = data.iloc[train_size:].drop(columns=["Label", "Dataset"])
    
    train_documents = prepare_packet_documents_luflow(train_data) if use_luflow else prepare_packet_documents_uq(train_data)
    val_documents = train_documents[:int(len(train_documents) * 0.1)]
    train_documents = train_documents[int(len(train_documents) * 0.1):]
    model_name = 'all-MiniLM-L6-v2'
    fine_tuned_model = fine_tune_embedding_model(train_documents, val_documents, label_field='Attack', model_name=model_name, epochs=3)
    # Create embeddings
    embedded_docs = create_embeddings(train_documents, fine_tuned_model)
    collection = store_in_vector_db(embedded_docs, client)

Fine-tuning model with 40000 training pairs


  return F.linear(input, self.weight, self.bias)


RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Float'

In [42]:
test_documents = prepare_packet_documents_luflow(test_data) if use_luflow else prepare_packet_documents_uq(test_data)

In [43]:
example_doc = test_documents[1].copy()
example_doc['content'] = re.sub(r'Attack: .+', 'Attack: <unknown>', example_doc['content'])
example_doc['content']

'Flow ID: 154966\nConnection: 175.45.176.2:0 → 149.171.126.17:0\nProtocol: 99 (L7: 0.0)\nTraffic Volume: 200 bytes total (2 packets)\nInbound: 200 bytes in 2 packets\nOutbound: 0 bytes in 0 packets\nDNS Query ID: 0\nDNS Query Type: 0\nDNS TTL Answer: 0\nFTP Return Code: 0.0\nFlow Duration: 0 ms\nDuration In/Out: 0/0\nThroughput Client→Server: 1600000 bytes/s\nThroughput Server→Client: 0 bytes/s\nPacket Size Distribution:\n  ≤128 bytes: 2 packets\n  128-256 bytes: 0 packets\n  256-512 bytes: 0 packets\n  512-1024 bytes: 0 packets\n  1024-1514 bytes: 0 packets\nTTL Range: 254-254\nPacket Length Range: 0-100\nShortest/Longest Packet: 100/100\nAttack: <unknown>'

In [44]:
results = retrieve_similar_packets(example_doc['content'], collection, fine_tuned_model, n=5)
metadata = results['metadatas']
attacks = [match_dict['Attack'] for match_dict in metadata[0]]
attack_counts = Counter(attacks)
attack_counts.most_common(1)

[('Exploits', 4)]

In [45]:
example_doc['metadata']['Attack']

'Analysis'

In [46]:
def validate_vector_db(test_docs, model):
    correct = 0
    total = 0
    
    # Initialize counters for precision and recall calculation
    true_positives = Counter()
    true_negatives = Counter()
    false_positives = Counter()
    false_negatives = Counter()
    
    for doc in tqdm(test_docs):
        # Query the collection
        doc_copy = doc.copy()
        doc_copy['content'] = re.sub(r'Attack: .+', 'Attack: <unknown>', doc_copy['content'])
        results = retrieve_similar_packets(doc_copy['content'], collection, model, n=5)
        metadata = results['metadatas']
        attacks = [match_dict['Attack'] for match_dict in metadata[0]]
        attack_counts = Counter(attacks)
        
        # Determine most common predicted attack
        most_common_attack = attack_counts.most_common(1)[0][0]
        expected_attack = doc['metadata']['Attack']
        
        # Check if prediction is correct
        if most_common_attack == expected_attack:
            correct += 1
            # The most common attack is correct - count as true positive
            true_positives[expected_attack] += 1
        else:
            print(f"Expected: {expected_attack}, Predicted: {most_common_attack}")
            # The prediction was wrong - count as false positive
            false_positives[expected_attack] += 1
            
        # Check for other attacks in the results that weren't the most common
        # If the expected attack is in results but not the most common, it's missed
        if expected_attack in attacks and most_common_attack != expected_attack:
            false_negatives[expected_attack] += 1
        # If the expected attack is not in results at all, it's also missed
        elif expected_attack not in attacks:
            false_negatives[expected_attack] += 1
            
        total += 1
    precision = {}
    recall = {}
    f1 = {}

    # Calculate overall metrics
    for attack in true_positives:
        true_positives[attack] = true_positives[attack] / total
        false_positives[attack] = false_positives[attack] / total
        false_negatives[attack] = false_negatives[attack] / total

        # Calculate precision, recall, and F1 for each attack
        precision[attack] = true_positives[attack] / (true_positives[attack] + false_positives[attack]) if (true_positives[attack] + false_positives[attack]) > 0 else 0
        recall[attack] = true_positives[attack] / (true_positives[attack] + false_negatives[attack]) if (true_positives[attack] + false_negatives[attack]) > 0 else 0
        f1[attack] = 2 * (precision[attack] * recall[attack]) / (precision[attack] + recall[attack]) if (precision[attack] + recall[attack]) > 0 else 0
    
        print(f"Precision: {precision[attack]:.2f}, Recall: {recall[attack]:.2f}, F1: {f1[attack]:.2f}")
        print(f"True Positives: {true_positives[attack]}, False Positives: {false_positives[attack]}, False Negatives: {false_negatives[attack]}")
    print(f"Correct: {correct}, Total: {total}, Accuracy: {correct/total:.2f}")
    
    return correct / total
    
test_accuracy = validate_vector_db(test_documents, fine_tuned_model)
test_accuracy

 14%|█▍        | 14/100 [00:00<00:01, 67.41it/s]

Expected: scanning, Predicted: injection
Expected: Analysis, Predicted: Exploits
Expected: Backdoor, Predicted: password
Expected: Generic, Predicted: Exploits
Expected: Worms, Predicted: Generic
Expected: mitm, Predicted: injection
Expected: DDoS, Predicted: injection
Expected: DDoS, Predicted: password
Expected: scanning, Predicted: password


 28%|██▊       | 28/100 [00:00<00:01, 67.52it/s]

Expected: Generic, Predicted: Exploits
Expected: Worms, Predicted: Generic
Expected: Theft, Predicted: password
Expected: scanning, Predicted: injection
Expected: scanning, Predicted: password
Expected: injection, Predicted: password
Expected: DoS, Predicted: injection
Expected: Analysis, Predicted: Exploits
Expected: Generic, Predicted: Exploits


 44%|████▍     | 44/100 [00:00<00:00, 69.72it/s]

Expected: password, Predicted: injection
Expected: Generic, Predicted: Exploits
Expected: DDoS, Predicted: injection
Expected: password, Predicted: injection
Expected: Fuzzers, Predicted: Exploits
Expected: injection, Predicted: password
Expected: scanning, Predicted: injection
Expected: Theft, Predicted: injection
Expected: Theft, Predicted: injection


 58%|█████▊    | 58/100 [00:00<00:00, 67.63it/s]

Expected: ransomware, Predicted: password
Expected: Worms, Predicted: Generic
Expected: Analysis, Predicted: Exploits
Expected: DDoS, Predicted: Brute Force
Expected: Reconnaissance, Predicted: injection
Expected: mitm, Predicted: injection
Expected: Theft, Predicted: password
Expected: Infilteration, Predicted: Brute Force
Expected: DoS, Predicted: injection


 74%|███████▍  | 74/100 [00:01<00:00, 70.21it/s]

Expected: ransomware, Predicted: password
Expected: Analysis, Predicted: Exploits
Expected: scanning, Predicted: ransomware
Expected: Analysis, Predicted: Exploits
Expected: Worms, Predicted: Generic
Expected: Shellcode, Predicted: Exploits
Expected: scanning, Predicted: password
Expected: DDoS, Predicted: password
Expected: injection, Predicted: xss
Expected: Backdoor, Predicted: injection
Expected: password, Predicted: injection


 90%|█████████ | 90/100 [00:01<00:00, 71.47it/s]

Expected: Infilteration, Predicted: Brute Force
Expected: mitm, Predicted: password
Expected: Worms, Predicted: Generic
Expected: DoS, Predicted: injection
Expected: Worms, Predicted: Generic
Expected: ransomware, Predicted: password
Expected: Brute Force, Predicted: Infilteration
Expected: xss, Predicted: injection
Expected: Backdoor, Predicted: password
Expected: Theft, Predicted: injection
Expected: DDoS, Predicted: injection


100%|██████████| 100/100 [00:01<00:00, 69.73it/s]

Expected: Worms, Predicted: Generic
Expected: Worms, Predicted: Generic
Expected: password, Predicted: injection
Expected: DoS, Predicted: injection
Expected: Reconnaissance, Predicted: injection
Precision: 1.00, Recall: 1.00, F1: 1.00
True Positives: 0.07, False Positives: 0.0, False Negatives: 0.0
Precision: 1.00, Recall: 1.00, F1: 1.00
True Positives: 0.08, False Positives: 0.0, False Negatives: 0.0
Precision: 0.56, Recall: 0.56, F1: 0.56
True Positives: 0.05, False Positives: 0.04, False Negatives: 0.04
Precision: 0.86, Recall: 0.86, F1: 0.86
True Positives: 0.06, False Positives: 0.01, False Negatives: 0.01
Precision: 0.62, Recall: 0.62, F1: 0.62
True Positives: 0.05, False Positives: 0.03, False Negatives: 0.03
Precision: 0.60, Recall: 0.60, F1: 0.60
True Positives: 0.03, False Positives: 0.02, False Negatives: 0.02
Precision: 0.50, Recall: 0.50, F1: 0.50
True Positives: 0.03, False Positives: 0.03, False Negatives: 0.03
Correct: 37, Total: 100, Accuracy: 0.37





0.37

In [72]:
def query_with_filters(collection, query_text, metadata_filters, n_results=5):
    """Query with both semantic search and metadata filters."""
    
    # Load embedding model
    model = SentenceTransformer('all-MiniLM-L6-v2')
    
    # Encode the query
    query_embedding = model.encode(query_text).tolist()
    
    # Convert flat filters to proper ChromaDB format
    # ChromaDB requires a single operator at the top level
    formatted_where = {"$and": []}
    
    for key, value in metadata_filters.items():
        if isinstance(value, dict):
            # This is already an operator (like {"$gt": 4.0})
            formatted_where["$and"].append({key: value})
        else:
            # This is a direct value match
            formatted_where["$and"].append({key: {"$eq": value}})
    
    # Search the collection with properly formatted filters
    results = collection.query(
        query_embeddings=[query_embedding],
        where=formatted_where,
        n_results=n_results
    )
    
    return results

In [15]:
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

model_name = "ehristoforu/coolqwen-3b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create the text-generation pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1000)

Downloading shards: 100%|██████████| 2/2 [02:25<00:00, 72.53s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]
Device set to use cuda:0


In [16]:
import json
import time

# Function to retrieve similar network packets for LLM context
def retrieve_context_for_llm(query_packet, collection, n=5):
    results = retrieve_similar_packets(query_packet, collection, n=n)
    context = "SIMILAR NETWORK TRAFFIC EXAMPLES:\n\n"
    if 'documents' in results and results['documents'][0]:
        for i, doc in enumerate(results['documents'][0]):
            distance = results['distances'][0][i] if 'distances' in results else "unknown"
            context += f"--- Example {i+1} (Similarity: {1/(1+distance):.2f}) ---\n"
            context += doc + "\n\n"
    return context

# Function to query an LLM using Hugging Face Transformers
def query_llm(prompt, model, max_retries=3):
    for attempt in range(max_retries):
        try:
            response = model(prompt, return_full_text=False)
            return response[0]['generated_text']
        except Exception as e:
            if attempt < max_retries - 1:
                sleep_time = 2 ** attempt
                print(f"Error querying LLM: {e}. Retrying in {sleep_time}s...")
                time.sleep(sleep_time)
            else:
                print(f"Failed to query LLM after {max_retries} attempts: {e}")
                return "{}"

In [29]:
def nids_rag_analysis(packet_content, collection, model):
    """
    Perform RAG-based network intrusion detection
    
    Args:
        packet_content: The formatted packet content to analyze
        collection: Vector database collection containing network traffic examples
        llm_client: Client for your LLM (e.g., OpenAI, Anthropic, etc.)
        model: SentenceTransformer model for embedding
    
    Returns:
        Dict with analysis results
    """
    # Retrieve similar examples as context
    context = retrieve_context_for_llm(packet_content, collection)
    
    # Construct prompt for the LLM
    prompt = f"""You are a network security expert analyzing network traffic for potential intrusions.
    
I'll provide you with:
1. A network packet to analyze
2. Examples of similar network traffic from our database with known classifications

Based on the examples and your expertise, determine:
- If this packet is malicious or benign
- If malicious, what type of attack it represents (e.g., DoS, Port Scan, DDoS, etc.)
- Confidence level (0-1)
- Explanation of your reasoning

Here is the packet to analyze:

{packet_content}

Here are similar packets from our database for context:

{context}

Provide your analysis in JSON format with the following fields:
{{
  "is_malicious": true/false,
  "attack_type": "attack name or null if benign",
  "confidence": 0.xx,
  "explanation": "your detailed reasoning"
}}

Return only valid JSON, with no additional text.
"""
    
    # Send to LLM and get response
    response = query_llm(prompt, model)
    # remove newlines from response
    response = response.replace("\n", "")

    # Extract JSON from the response using regex
    json_matches = re.findall(r'\{.*?\}', response, re.DOTALL)
    
    if json_matches:
        json_str = json_matches[0]
        try:
            analysis = json.loads(json_str)
            return analysis
        except json.JSONDecodeError:
            pass
    
    # Parse the JSON response
    try:
        analysis = json.loads(response)
        return analysis
    except json.JSONDecodeError:
        # Fallback if LLM doesn't return valid JSON
        return {
            "is_malicious": None,
            "attack_type": None,
            "confidence": 0,
            "explanation": "Error parsing LLM response",
            "raw_response": response
        }

In [30]:
nids_rag_analysis(example_doc['content'], collection, pipe)

{'is_malicious': True,
 'attack_type': 'scanning',
 'confidence': 0.99,
 'explanation': 'The packet has a high similarity score (0.99) with other benign scanning packets. It fits the pattern of a scanning attack, where the client initiates a TCP connection without establishing a session, sending small packets over a wide range of ports to identify open services on the target host.'}