In [17]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install langchain langchain-community langchain-ollama python-dotenv pydantic instructor ollama
# (logging is part of the Python standard library and must not be installed with pip.)

import os
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import instructor
from ollama import Client  # Added for raw response debugging

# Load environment variables (python-dotenv; presumably from a local .env file)
load_dotenv()

# --- Configuration ---
# Use environment variables for paths and settings; the second argument to
# os.getenv is the fallback used when the variable is unset.
LLM_MODEL = os.getenv("LLM_MODEL", "llama3:8b")  # Changed default to a more reliable model for testing
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
# Input (DFD components) and output (threat list) both default to files inside INPUT_DIR.
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging: INFO level, "timestamp - level - message" lines
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists early
# NOTE(review): only INPUT_DIR is created; if THREATS_OUTPUT_PATH is overridden
# to a different directory, that directory is not created here.
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize LLM with Instructor for schema enforcement
# (used below with response_model=... to get a validated pydantic object back)
llm = instructor.from_provider(f"ollama/{LLM_MODEL}", mode=instructor.Mode.JSON_SCHEMA)

# Added: Raw Ollama client for debugging (used only to log the unstructured reply)
ollama_client = Client()

# --- Threat Schema for Validation ---
# One STRIDE threat as returned by the LLM. All fields are required. The Field
# descriptions state the expected content, but the annotations are plain
# str / list[str], so values such as the STRIDE category or the
# Low/Medium/High ratings are not restricted by this model.
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One STRIDE category: Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege.")
    threat_description: str = Field(description="Clear, specific description of the threat.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low/Medium/High based on potential damage.")
    likelihood: str = Field(description="Low/Medium/High based on exploitability.")
    references: list[str] = Field(description="Array of standard references (e.g., ['OWASP A01:2021', 'NIST SI-2']).")

# Shape requested from the LLM (passed as response_model to Instructor below):
# a single "threats" key holding a list of Threat objects.
class Threats(BaseModel):
    threats: list[Threat]

# Shape of the saved result: the LLM's threats plus a free-form metadata dict
# (the invocation code fills it with a timestamp and the source DFD path).
class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
# Minimal stand-in diagram: a user logging in to a web application that
# queries a user database. Flow rows are (from, to, data, protocol).
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        dict(zip(("from", "to", "data", "protocol"), flow_row))
        for flow_row in (
            ("User", "Web Application", "Login Credentials", "HTTP"),
            ("Web Application", "User Database", "Query User Data", "SQL"),
        )
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"],
}

# --- Load DFD Components ---
# Read the DFD JSON from DFD_INPUT_PATH. A missing, empty, or unparsable file
# falls back to SAMPLE_DFD so the cell can still be exercised; any other
# failure aborts the run.
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    # Recoverable: execution continues with the sample, so this is not labelled fatal.
    logger.error(f"--- Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error("--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    # `exit()` is an interactive helper added by the `site` module; raise
    # SystemExit directly so this works in scripts and notebooks alike.
    raise SystemExit(1)
else:
    if not dfd_data:  # Valid JSON, but nothing in it
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    else:
        # Only report success when the file's own content is actually used.
        logger.info("--- DFD components loaded successfully ---")

# --- Prompt Engineering for Threat Generation ---
# Single prompt covering the whole DFD at once. The only template variable is
# {dfd_json}, which is filled with the pretty-printed DFD JSON at invocation time.
# NOTE(review): the template is parsed by ChatPromptTemplate.from_template, so
# any literal braces added to this text later presumably need escaping — confirm
# against the LangChain docs before editing.
threat_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where applicable.
2. Describe threats considering trust boundaries, protocols, and potential attack vectors (e.g., injection, misconfiguration).
3. Suggest mitigations with references to standards (e.g., "NIST AC-6 for least privilege").
4. Assess impact (Low/Medium/High based on potential damage) and likelihood (Low/Medium/High based on exploitability).

For each threat, include:
- 'component_name': Affected asset, process, data flow, or entity.
- 'stride_category': One STRIDE category.
- 'threat_description': Clear, specific description (e.g., "Attacker intercepts unencrypted data in transit leading to disclosure").
- 'mitigation_suggestion': Practical, actionable mitigation (e.g., "Implement TLS 1.3 with certificate pinning").
- 'impact': Low/Medium/High.
- 'likelihood': Low/Medium/High.
- 'references': Array of strings (e.g., ["OWASP A01:2021", "NIST SI-2"]).

DFD Components:
---
{dfd_json}
---

Generate a JSON object with a key 'threats' (array of threat objects). Output ONLY the JSON, with no additional commentary or formatting.
"""

# Compiled template; format_messages(dfd_json=...) is called on it below.
threat_prompt = ChatPromptTemplate.from_template(threat_prompt_template)

# --- Invocation and Output ---
# Render the prompt, log the raw (unstructured) reply for debugging, obtain the
# schema-validated reply through Instructor, attach metadata, validate the
# combined document, then save and print it.
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Convert the loaded DFD dictionary back to a JSON string for the prompt
    dfd_json_string = json.dumps(dfd_data, indent=2)

    # Generate messages from the prompt template; only the first message is sent
    messages = threat_prompt.format_messages(dfd_json=dfd_json_string)
    prompt_text = messages[0].content
    chat_messages = [{"role": "user", "content": prompt_text}]

    # Log the prompt for debugging
    logger.info(f"--- Prompt sent to LLM ---\n{prompt_text}")

    # Debug aid only: log the raw Ollama reply before the structured call.
    # Isolated in its own try so a failure here cannot abort threat generation.
    try:
        raw_response = ollama_client.chat(model=LLM_MODEL, messages=chat_messages)
        logger.info(f"--- Raw LLM Response ---\n{raw_response['message']['content']}")
    except Exception as debug_error:
        logger.warning(f"--- Raw debug call failed (continuing with structured call): {debug_error} ---")

    # Invoke the LLM with Instructor for structured output
    threats_obj = llm.chat.completions.create(
        messages=chat_messages,
        response_model=Threats,
        max_retries=5  # Increased for better handling
    )

    threats_dict = threats_obj.model_dump()

    # Add metadata
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }

    # Validate the output against schema (Instructor already enforces, but double-check)
    try:
        ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise

    # Save the threats to a new file (explicit UTF-8 so output is platform-independent)
    with open(THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error("\n--- An error occurred during threat generation ---")
    # exc_info adds the traceback so the failing step can be identified.
    logger.error(f"Error: {e}", exc_info=True)
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input data.")

2025-07-27 19:49:51,081 - INFO - Initializing ollama provider with model llama3:8b
2025-07-27 19:49:51,103 - INFO - Client initialized
2025-07-27 19:49:51,118 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 19:49:51,119 - INFO - --- DFD components loaded successfully ---
2025-07-27 19:49:51,119 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-27 19:49:51,120 - INFO - --- Prompt sent to LLM ---

You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where 

{
  "threats": [
    {
      "component_name": "U",
      "stride_category": "Information Disclosure",
      "threat_description": "Unencrypted data transmitted from U to CDN potentially disclosed",
      "mitigation_suggestion": "Implement end-to-end encryption; Use HTTPS protocol",
      "impact": "Medium",
      "likelihood": "High",
      "references": [
        "OWASP A01:2021",
        "NIST SI-2"
      ]
    },
    {
      "component_name": "CDN",
      "stride_category": "Tampering",
      "threat_description": "Malicious actor intercepts and modifies data in transit from CDN to LB",
      "mitigation_suggestion": "Implement integrity checking; Use digital signatures",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A03:2021",
        "MITRE CA-8"
      ]
    },
    {
      "component_name": "LB",
      "stride_category": "Elevation of Privilege",
      "threat_description": "Unprivileged actor gains elevated privileges on LB, potentia

In [52]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install openai pydantic python-dotenv langchain langchain_community langchain_huggingface faiss-cpu pypdf sentence-transformers requests
# (logging is part of the Python standard library and must not be installed with pip.)

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from openai import OpenAI
import requests  # Added for auto-download

# RAG specific imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load environment variables
load_dotenv()

# --- Configuration ---
# Each setting falls back to the literal default when its environment variable is unset.
LLM_MODEL = os.getenv("LLM_MODEL", "llama-3.1-70b-instruct") # Updated model suggestion
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# RAG Configuration
# Overridable through the environment like the other paths (the refinement
# step already reads RAG_DOCS_DIR this way); defaults are unchanged.
RAG_DOCS_DIR = os.getenv("RAG_DOCS_DIR", "rag_docs")
FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "faiss_index")

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure directories exist
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(RAG_DOCS_DIR, exist_ok=True)


def setup_rag_pipeline():
    """Create or load the FAISS vector store used for retrieval.

    If FAISS_INDEX_PATH exists, the saved index is loaded. Otherwise every PDF,
    Markdown and text file under RAG_DOCS_DIR is loaded (falling back to
    downloading the OWASP Top 10 2021 PDF when none is found), split into
    overlapping chunks, embedded, indexed, and saved to FAISS_INDEX_PATH.

    Returns:
        The FAISS vector store.

    Raises:
        ValueError: If no documents are available, including after the
            fallback download.
    """
    logger.info("--- Setting up RAG pipeline ---")

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    if os.path.exists(FAISS_INDEX_PATH):
        logger.info(f"--- Loading existing FAISS index from '{FAISS_INDEX_PATH}' ---")
        # SECURITY NOTE(review): allow_dangerous_deserialization opts in to
        # loading serialized Python objects from disk (pickle, per the LangChain
        # docs — confirm). Only load an index this pipeline wrote itself.
        return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)

    logger.info("--- No existing FAISS index found. Building a new one. ---")

    loaders = {
        "**/*.pdf": PyPDFLoader,
        "**/*.md": TextLoader,
        "**/*.txt": TextLoader
    }
    documents = []
    for pattern, loader_cls in loaders.items():
        try:
            loader = DirectoryLoader(RAG_DOCS_DIR, glob=pattern, loader_cls=loader_cls, show_progress=True, use_multithreading=True, silent_errors=True)
            documents.extend(loader.load())
        except Exception as e:
            # Best effort: one failing file type must not prevent loading the others.
            logger.warning(f"Could not load files with pattern {pattern} using {loader_cls.__name__}. Error: {e}")

    no_documents_message = f"No documents available for RAG. Please add files to '{RAG_DOCS_DIR}'."

    if not documents:
        logger.warning(f"--- No supported documents found in '{RAG_DOCS_DIR}'. Attempting to auto-download key resources ---")
        # Auto-download OWASP Top 10 2021 PDF (since 2025 not released as of July 2025)
        owasp_url = "https://owasp.org/Top10/assets/PDF/OWASP-Top-10-2021.pdf"
        try:
            # Without a timeout, requests can wait indefinitely on a stalled server.
            response = requests.get(owasp_url, timeout=60)
            response.raise_for_status()
            owasp_path = os.path.join(RAG_DOCS_DIR, "owasp_top10_2021.pdf")
            with open(owasp_path, 'wb') as f:
                f.write(response.content)
            logger.info(f"--- Downloaded OWASP Top 10 2021 PDF to '{owasp_path}' ---")
            # Reload PDF loader
            pdf_loader = DirectoryLoader(RAG_DOCS_DIR, glob="**/*.pdf", loader_cls=PyPDFLoader)
            documents.extend(pdf_loader.load())
        except Exception as e:
            logger.error(f"--- Failed to auto-download OWASP PDF: {e} ---")
            # Chain the cause so the original download/parse error stays visible.
            raise ValueError(no_documents_message) from e

    if not documents:
        # The download succeeded but produced no loadable pages; fail with the
        # same clear message instead of handing an empty corpus to FAISS.
        raise ValueError(no_documents_message)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(documents)

    logger.info(f"--- Creating FAISS index from {len(docs)} document chunks. This may take a moment... ---")
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(FAISS_INDEX_PATH)
    logger.info(f"--- FAISS index created and saved to '{FAISS_INDEX_PATH}' ---")

    return db

# --- Initialize RAG and OpenAI Client ---
# Build (or load) the vector store, then create an OpenAI-compatible client
# pointed at the Scaleway endpoint. Any failure is logged and re-raised.
try:
    rag_db = setup_rag_pipeline()
    # Fail fast with a clear message when the key is missing, rather than
    # passing api_key=None to the client (which presumably falls back to other
    # credentials such as OPENAI_API_KEY or fails later — confirm).
    if not os.getenv("SCW_SECRET_KEY"):
        raise ValueError("SCW_SECRET_KEY is not set; cannot authenticate against the LLM endpoint.")
    client = OpenAI(
        # The endpoint embeds a project ID; LLM_BASE_URL allows pointing at
        # another project without editing code. The default is unchanged.
        base_url=os.getenv("LLM_BASE_URL", "https://api.scaleway.ai/4a8fd76b-8606-46e6-afe6-617ce8eeb948/v1"),
        api_key=os.getenv("SCW_SECRET_KEY")
    )
    logger.info("--- OpenAI client initialized successfully ---")
except Exception as e:
    logger.error(f"--- Failed to initialize services: {e} ---")
    raise

# --- Threat Schema for Validation ---
# One threat in the final output. All fields are required; the ratings are
# plain strings here, so their allowed values are not enforced by this model.
# risk_score is overwritten by calculate_risk_score() during de-duplication.
class Threat(BaseModel):
    component_name: str
    stride_category: str
    threat_description: str
    mitigation_suggestion: str
    impact: str
    likelihood: str
    references: list[str]
    risk_score: str

# Shape of the saved document: the combined threat list plus a free-form
# metadata dict (timestamp, source DFD, model and index names are added below).
class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing ---
# Fallback diagram used when the DFD file is missing or empty. Components are
# {"name": ...} objects; flows carry source, destination, data_description, protocol.
SAMPLE_DFD = {
    "external_entities": [{"name": entity} for entity in ("User",)],
    "processes": [
        {"name": process}
        for process in ("Web Application", "Authentication Service")
    ],
    "data_stores": [{"name": store} for store in ("User Database",)],
    "data_flows": [
        {
            "source": "User",
            "destination": "Web Application",
            "data_description": "Login Credentials",
            "protocol": "HTTPS",
        },
        {
            "source": "Web Application",
            "destination": "User Database",
            "data_description": "Query User Data",
            "protocol": "SQL",
        },
    ],
    "trust_boundaries": [{"name": boundary} for boundary in ("Internet to DMZ",)],
}

# --- Load DFD Components ---
# A missing or empty DFD file falls back to SAMPLE_DFD; a file that exists but
# cannot be parsed (or any other error) aborts the run.
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
except FileNotFoundError:
    logger.warning(f"DFD file not found at '{DFD_INPUT_PATH}'. Using sample DFD for demonstration.")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError as e:
    logger.error(f"FATAL: Error decoding JSON from '{DFD_INPUT_PATH}': {e}")
    # `exit()` is an interactive helper added by the `site` module; raise
    # SystemExit directly so this works in scripts and notebooks alike.
    raise SystemExit(1)
except Exception as e:
    logger.error(f"FATAL: Error loading DFD: {e}")
    raise SystemExit(1)
else:
    if not dfd_data:
        logger.warning(f"DFD file at '{DFD_INPUT_PATH}' is empty. Using sample DFD for demonstration.")
        dfd_data = SAMPLE_DFD

# --- Enhanced Prompting Strategy ---

# **FIX 1: Define STRIDE categories explicitly to ensure systematic coverage.**
# Maps the one-letter category code to (display name, definition). Both parts
# are substituted into the per-category prompt.
STRIDE_DEFINITIONS = {
    "S": ("Spoofing", "Illegitimately accessing systems or data by impersonating a user, process, or component."),
    "T": ("Tampering", "Unauthorized modification of data, either in transit or at rest."),
    "R": ("Repudiation", "A user or system denying that they performed an action, often due to a lack of sufficient proof (e.g., logs)."),
    "I": ("Information Disclosure", "Exposing sensitive information to unauthorized individuals."),
    "D": ("Denial of Service", "Preventing legitimate users from accessing a system or service."),
    "E": ("Elevation of Privilege", "A user or process gaining rights beyond their authorized level.")
}

# Optional: Load custom STRIDE from file
# Expected shape: {"<letter>": ["<name>", "<definition>"], ...}. The file is
# validated up front; a malformed file keeps the built-in definitions instead
# of crashing here or later when the entries are unpacked in the main loop.
stride_config_path = "stride_config.json"
if os.path.exists(stride_config_path):
    try:
        with open(stride_config_path, 'r', encoding='utf-8') as f:
            custom_definitions = json.load(f)
        if not isinstance(custom_definitions, dict) or not custom_definitions:
            raise ValueError("expected a non-empty JSON object")
        # Unpacking enforces exactly two items per entry (ValueError/TypeError otherwise).
        STRIDE_DEFINITIONS = {
            letter: (name, definition)
            for letter, (name, definition) in custom_definitions.items()
        }
        logger.info(f"--- Loaded custom STRIDE definitions from '{stride_config_path}' ---")
    except (OSError, ValueError, TypeError) as e:
        # json.JSONDecodeError is a ValueError subclass, so parse errors land here too.
        logger.warning(f"--- Ignoring invalid '{stride_config_path}' ({e}); using built-in STRIDE definitions ---")

# **FIX 2: Create a highly specific prompt template focused on a SINGLE STRIDE category.**
# This prevents generic responses and forces the model to generate relevant, accurate threats.
# Filled with str.format() in the main loop. Placeholders: {component_info},
# {stride_category} (the one-letter code), {stride_name}, {stride_definition},
# {rag_context}. The doubled braces around the schema are str.format escapes
# that render as literal { }.
threat_prompt_template_specific_rag = """
You are a cybersecurity architect specializing in threat modeling using the STRIDE methodology.
Your task is to generate 1-2 specific threats for a given DFD component, focusing ONLY on a single STRIDE category.

**DFD Component to Analyze:**
{component_info}

**STRIDE Category to Focus On:**
- **{stride_category} ({stride_name}):** {stride_definition}

**Security Context from Knowledge Base (for accuracy):**
'''
{rag_context}
'''

**Instructions:**
1.  Generate 1-2 distinct and realistic threats for the component that fall **strictly** under the '{stride_name}' category.
2.  **Be specific.** Relate the threat directly to the component's type and details. For a database, a Spoofing threat is a spoofed connection, not user impersonation. For a data flow, a Tampering threat is a Man-in-the-Middle attack.
3.  Use the provided Security Context to create specific descriptions, **actionable mitigations**, and accurate references (e.g., CWE, OWASP Cheat Sheets). Do not invent references.
4.  Provide a realistic risk assessment (Impact, Likelihood, Score).
5.  Output ONLY a valid JSON object with a single key "threats", containing a list of threat objects. Do not include any other text or commentary.

**JSON Threat Object Schema:**
{{
  "component_name": "string (the name of the component being analyzed)",
  "stride_category": "{stride_category}",
  "threat_description": "string (Specific to the component and STRIDE category)",
  "mitigation_suggestion": "string (Actionable and specific)",
  "impact": "Low, Medium, or High",
  "likelihood": "Low, Medium, or High",
  "references": ["list of strings, e.g., 'OWASP A01:2021', 'CWE-89'"],
  "risk_score": "Critical, High, Medium, or Low"
}}
"""

# Appended to the prompt when an attempt returned no threats and the request is retried.
retry_prompt_addition = " Generate at least one threat if realistically applicable, even if minor."

# Function to calculate risk_score
def calculate_risk_score(impact, likelihood):
    """Derive a risk rating from impact and likelihood.

    Ratings are matched case-insensitively and ignoring surrounding whitespace,
    because they come straight from LLM output ("high", "HIGH ", ...). Before
    this normalization any non-canonical spelling silently scored "Low".

    Args:
        impact: "Low", "Medium" or "High" (any casing).
        likelihood: "Low", "Medium" or "High" (any casing).

    Returns:
        "Critical", "High", "Medium" or "Low". Unrecognized or non-string
        inputs yield "Low".
    """
    def _level(value):
        # Non-strings can never match a rating; map them to None (hashable).
        return value.strip().capitalize() if isinstance(value, str) else None

    risk_matrix = {
        ("High", "High"): "Critical",
        ("High", "Medium"): "Critical",
        ("High", "Low"): "High",
        ("Medium", "High"): "High",
        ("Medium", "Medium"): "Medium",
        ("Medium", "Low"): "Medium",
        ("Low", "High"): "Medium",
    }
    return risk_matrix.get((_level(impact), _level(likelihood)), "Low")

# --- Main Invocation Logic ---
# For every DFD component and every STRIDE category: build a category-specific
# prompt (with RAG context), ask the LLM for threats, then de-duplicate,
# re-score, sort, validate and save the combined result.
logger.info("\n--- Invoking LLM with RAG to systematically generate STRIDE threats ---")
all_threats = []
try:
    # Flatten the DFD into {"type": <section key>, "details": <item dict>} records.
    components_to_analyze = []
    for key, value in dfd_data.items():
        if not isinstance(value, list):
            continue
        for item in value:
            if isinstance(item, dict):
                # Components with and without a 'name' are both analysed; unnamed
                # ones are identified by their JSON text below.
                components_to_analyze.append({"type": key, "details": item})
            else:
                # Previously dropped silently; make the skip visible.
                logger.warning(f"--- Skipping non-object entry in '{key}': {item!r} ---")

    # **FIX 3: Iterate through each component AND each STRIDE category.**
    # This loop structure ensures every category is considered for every component.
    for component in components_to_analyze:
        component_str = json.dumps(component)
        component_name = component["details"].get("name", component_str)
        logger.info(f"\n--- Analyzing component: {component_name} ---")

        # One retrieval per component; the same context is reused for all categories.
        retrieved_docs = rag_db.similarity_search(component_str, k=5)  # Increased to 5 for broader context
        rag_context = "\n---\n".join(doc.page_content for doc in retrieved_docs)
        logger.info("--- Retrieved RAG context for component ---")

        for cat_letter, (cat_name, cat_def) in STRIDE_DEFINITIONS.items():
            logger.info(f"--- Generating threats for STRIDE category: {cat_name} ---")

            base_prompt = threat_prompt_template_specific_rag.format(
                component_info=component_str,
                rag_context=rag_context,
                stride_category=cat_letter,
                stride_name=cat_name,
                stride_definition=cat_def
            )

            max_retries = 1  # Retry once if no threats
            threats = []
            for attempt in range(max_retries + 1):
                # On a retry, append the nudge to the pristine prompt (never compounding it).
                prompt = base_prompt + retry_prompt_addition if attempt else base_prompt
                try:
                    response = client.chat.completions.create(
                        model=LLM_MODEL,
                        messages=[{"role": "user", "content": prompt}],
                        response_format={"type": "json_object"},
                        max_tokens=2048,
                        temperature=0.4 # Slightly lower temp for more focused output
                    )

                    response_content = response.choices[0].message.content
                    generated_data = json.loads(response_content)

                    # Ensure the response is a dict with a 'threats' key which is a list
                    raw_threats = generated_data.get("threats") if isinstance(generated_data, dict) else None
                    if isinstance(raw_threats, list):
                        # Keep only JSON objects: a stray string/number entry would
                        # otherwise break de-duplication and sorting later and
                        # discard the entire run.
                        threats = [t for t in raw_threats if isinstance(t, dict)]
                        dropped = len(raw_threats) - len(threats)
                        if dropped:
                            logger.warning(f"--- Dropped {dropped} malformed threat entr(ies) for {cat_name} on {component_name} ---")
                        # Add component name if missing from LLM response
                        for threat in threats:
                            if not threat.get('component_name'):
                                threat['component_name'] = component_name
                        logger.info(f"--- Successfully generated {len(threats)} threat(s) for category {cat_name} ---")
                    else:
                        logger.warning(f"--- LLM response for {cat_name} on {component_name} had unexpected structure. ---")
                        logger.debug(f"Raw Response: {response_content}")

                except (json.JSONDecodeError, TypeError, AttributeError) as e:
                    # TypeError covers a None message content passed to json.loads.
                    logger.warning(f"--- Could not parse LLM response for {cat_name} on {component_name}: {e} ---")
                except Exception as e:
                    logger.error(f"--- An API error occurred for {cat_name} on {component_name}: {e} ---")

                if threats:
                    break

            all_threats.extend(threats)

    # Deduplication: Remove exact duplicates based on description
    unique_threats = []
    seen_descriptions = set()
    for threat in all_threats:
        desc = threat.get('threat_description', '')
        if desc in seen_descriptions:
            continue
        seen_descriptions.add(desc)
        # Recalculate risk_score from impact/likelihood rather than trusting the LLM's value
        threat['risk_score'] = calculate_risk_score(threat.get('impact', 'Low'), threat.get('likelihood', 'Low'))
        unique_threats.append(threat)
    all_threats = unique_threats
    logger.info(f"--- Deduplicated threats: {len(all_threats)} unique threats remaining ---")

    # --- Final Processing and Validation ---
    # Highest risk first; unknown scores sort last.
    risk_order = {"Critical": 4, "High": 3, "Medium": 2, "Low": 1, "Informational": 0}
    all_threats.sort(key=lambda t: risk_order.get(t.get('risk_score', 'Low'), 0), reverse=True)

    final_output = {
        "threats": all_threats,
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "source_dfd": os.path.basename(DFD_INPUT_PATH),
            "llm_model": LLM_MODEL,
            "rag_index": FAISS_INDEX_PATH
        }
    }

    # Validation is advisory: a failure is logged but the output is still
    # written, so a long LLM run is not lost over one malformed field.
    try:
        ThreatsOutput(**final_output)
        logger.info("--- Final JSON output validated successfully against schema ---")
    except ValidationError as ve:
        logger.error(f"--- FINAL JSON VALIDATION FAILED: {ve} ---")

    # Explicit UTF-8 so the saved file does not depend on the platform default.
    with open(THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, indent=2)

    logger.info("\n--- LLM RAG Output (Identified Threats) ---")
    logger.info(f"\n--- Identified {len(all_threats)} threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error("\n--- An error occurred during the threat generation process ---")
    logger.error(f"Error: {e}", exc_info=True)

2025-07-28 18:09:52,560 - INFO - --- Setting up RAG pipeline ---
2025-07-28 18:09:52,561 - INFO - Use pytorch device_name: mps
2025-07-28 18:09:52,562 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
2025-07-28 18:09:55,354 - INFO - --- Loading existing FAISS index from 'faiss_index' ---
2025-07-28 18:09:55,443 - INFO - --- OpenAI client initialized successfully ---
2025-07-28 18:09:55,444 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-28 18:09:55,445 - INFO - 
--- Invoking LLM with RAG to systematically generate STRIDE threats ---
2025-07-28 18:09:55,445 - INFO - 
--- Analyzing component: {"type": "data_flows", "details": {"source": "U", "destination": "CDN", "data_description": "", "protocol": "HTTPS"}} ---
2025-07-28 18:09:55,463 - INFO - --- Retrieved RAG context for component ---
2025-07-28 18:09:55,463 - INFO - --- Generating threats for STRIDE category: Spoofing ---
2025-07-28 18:10:22,703 - INFO - HTTP Re

In [56]:
# --- Dependencies ---
# pip install sentence-transformers scikit-learn pydantic python-dotenv requests
# (logging is part of the Python standard library and must not be installed with pip.)

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import DBSCAN
import requests

# --- Configuration ---
# All paths can be overridden through environment variables; defaults live under INPUT_DIR.
load_dotenv()
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
# NOTE(review): this reads the THREATS_OUTPUT_PATH variable — presumably on
# purpose, so the previous step's output becomes this step's input; confirm.
THREATS_INPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))
REFINED_THREATS_OUTPUT_PATH = os.getenv("REFINED_THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "refined_threats.json"))
CONTROLS_INPUT_PATH = os.getenv("CONTROLS_INPUT_PATH", os.path.join(INPUT_DIR, "controls.json"))
RAG_DOCS_DIR = os.getenv("RAG_DOCS_DIR", "rag_docs")

# Setup logging: INFO level, "timestamp - level - message" lines
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure directories exist
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(RAG_DOCS_DIR, exist_ok=True)

# --- Threat Schema for Validation ---
# Fully refined threat record. Every field is required (Field(...)). Unlike a
# plain str field, the pattern constraints reject anything but the exact,
# case-sensitive values listed: stride_category is a single letter from
# S/T/R/I/D/E, and the rating fields accept only the spellings in their regex.
class Threat(BaseModel):
    component_name: str = Field(..., description="Standardized name of the component or data flow")
    stride_category: str = Field(..., pattern="^[STRIDE]$", description="STRIDE category (S, T, R, I, D, E)")
    threat_description: str = Field(..., description="Detailed description of the threat")
    mitigation_suggestion: str = Field(..., description="Actionable mitigation specific to the threat")
    impact: str = Field(..., pattern="^(Low|Medium|High)$", description="Impact level")
    likelihood: str = Field(..., pattern="^(Low|Medium|High)$", description="Likelihood level")
    references: list[str] = Field(..., description="List of references (e.g., CWE, OWASP)")
    risk_score: str = Field(..., pattern="^(Critical|High|Medium|Low)$", description="Derived risk score")
    justification: str = Field(..., description="Rationale for impact and likelihood ratings")
    exploitability: str = Field(..., pattern="^(Low|Medium|High)$", description="Ease of exploitation")
    mitigation_maturity: str = Field(..., pattern="^(Immature|Mature|Advanced)$", description="Maturity of mitigation controls")
    risk_statement: str = Field(..., description="Business-contextualized risk description")

# Top-level shape of the refined output: a list of refined Threat records plus
# a free-form metadata dict.
class RefinedThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Helper Functions ---
def load_dfd_components():
    """Load the DFD file and return its list of data flows.

    Returns:
        The non-empty list stored under the "data_flows" key of the JSON
        document at DFD_INPUT_PATH.

    Raises:
        ValueError: If the document is not a JSON object or has no data flows.
        Exception: Any error raised while opening or parsing the file is
            logged and re-raised unchanged.
    """
    logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
    try:
        # JSON is UTF-8 by specification; do not depend on the platform default encoding.
        with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
            dfd_data = json.load(f)
    except Exception as e:
        logger.error(f"--- Failed to load DFD components: {e} ---")
        raise

    # Validation sits outside the try so its error is logged once, not twice,
    # and a non-object document fails with the same clear ValueError instead
    # of an AttributeError on .get().
    data_flows = dfd_data.get("data_flows", []) if isinstance(dfd_data, dict) else []
    if not data_flows:
        logger.error("--- No data flows found in DFD ---")
        raise ValueError("Invalid DFD: No data flows")
    return data_flows

def load_controls():
    """Return the client's security controls, or built-in defaults.

    The JSON document at CONTROLS_INPUT_PATH is returned as parsed when the
    file exists. If it is absent, or reading/parsing fails (a warning is
    logged in that case), a fresh default dict is returned with HTTPS and
    mTLS disabled and TLS version "1.2".
    """
    def _defaults():
        # Built per call so callers never share one mutable dict.
        return {"https_enabled": False, "tls_version": "1.2", "mtls_enabled": False}

    logger.info(f"--- Loading controls from '{CONTROLS_INPUT_PATH}' ---")
    try:
        if not os.path.exists(CONTROLS_INPUT_PATH):
            return _defaults()
        with open(CONTROLS_INPUT_PATH, 'r') as handle:
            return json.load(handle)
    except Exception as err:
        logger.warning(f"--- Failed to load controls, using defaults: {err} ---")
        return _defaults()

def calculate_risk_score(impact, likelihood):
    """Look up the risk rating for an impact/likelihood pair.

    Both arguments are matched exactly against "Low", "Medium" and "High";
    any other combination yields "Low".
    """
    # Outer key: impact. Inner key: likelihood.
    ratings_by_impact = {
        "High": {"High": "Critical", "Medium": "Critical", "Low": "High"},
        "Medium": {"High": "High", "Medium": "Medium", "Low": "Medium"},
        "Low": {"High": "Medium", "Medium": "Low", "Low": "Low"},
    }
    return ratings_by_impact.get(impact, {}).get(likelihood, "Low")

def assess_exploitability(threat, dfd_data):
    """Assess exploitability based on component exposure and protocol.

    Args:
        threat: Threat dict; only "component_name" is read.
        dfd_data: Iterable of data-flow dicts with "source", "destination"
            and, optionally, "protocol".

    Returns:
        "High" when the component name mentions "Public Zone" or a flow that
        starts at the entity "U"; otherwise "Medium" when the matching flow's
        protocol mentions TLS or HTTPS; otherwise "Low".
    """
    import re  # local import: keeps this block self-contained

    component = threat["component_name"]

    # Word boundaries stop "U to" from also matching sources that merely end
    # in "U" (e.g. "CPU to DB"), which a plain substring test rated "High".
    if "Public Zone" in component or re.search(r"\bU to\b", component):
        return "High"

    # Protocol of the flow whose "<source> to <destination>" label equals the
    # component name. .get() tolerates flows lacking a key instead of raising.
    protocol = next(
        (
            flow.get("protocol") or "Unknown"
            for flow in dfd_data
            if f"{flow.get('source')} to {flow.get('destination')}" == component
        ),
        "Unknown",
    )

    # NOTE(review): encrypted transports rate higher than unknown/plain ones
    # here; behaviour kept as in the original — confirm this is intended.
    if "TLS" in protocol or "HTTPS" in protocol:
        return "Medium"
    return "Low"

def assess_mitigation_maturity(mitigation):
    """Assess maturity of a mitigation from keywords in its text.

    Keywords are matched case-insensitively, so a sentence-initial "Logging"
    or "Rate limiting" is recognized. Only the acronym "WAF" stays
    case-sensitive, to avoid matching it inside ordinary words.

    Args:
        mitigation: Mitigation text. A non-string value is treated as empty.

    Returns:
        "Mature", "Immature" or "Advanced". The first matching group wins in
        that order; text with no keyword defaults to "Mature".
    """
    text = mitigation if isinstance(mitigation, str) else ""
    lowered = text.lower()

    if "WAF" in text or any(keyword in lowered for keyword in ("mtls", "rate limiting")):
        return "Mature"
    if any(keyword in lowered for keyword in ("logging", "audit")):
        return "Immature"
    if any(keyword in lowered for keyword in ("end-to-end encryption", "certificate pinning")):
        return "Advanced"
    return "Mature"

def standardize_component_name(threat, valid_components):
    """Standardize a threat's component name to the DFD "<source> to <destination>" form.

    Args:
        threat: Threat dict; only ``threat["component_name"]`` is read.
        valid_components: Iterable of DFD flow dicts with "source" and
            "destination" keys.

    Returns:
        The matching DFD flow label when one is found, otherwise the
        original component name unchanged.
    """
    original_name = threat["component_name"]
    valid_names = [f"{flow['source']} to {flow['destination']}" for flow in valid_components]

    # Strip the decorations seen in generated names, then collapse whitespace.
    normalized = original_name.replace("Data Flow from ", "").replace("HTTPS Data Flow", "").replace("data flow", "").strip()
    normalized = " ".join(normalized.split())

    if normalized in valid_names:
        return normalized

    # Fall back to a case-insensitive substring search. When several flow
    # labels are contained in the name (e.g. "WS to DB" and "WS to DB_P"),
    # the longest one is the most specific, so prefer it over whichever
    # happens to come first in the DFD.
    lowered = normalized.lower()
    candidates = [name for name in valid_names if name.lower() in lowered]
    if candidates:
        valid_name = max(candidates, key=len)
        logger.info(f"--- Matched '{original_name}' to '{valid_name}' via partial match ---")
        return valid_name

    logger.warning(f"--- Component name '{original_name}' not found in DFD; retaining original ---")
    return original_name

def generate_justification(threat, dfd_data):
    """Generate a tailored justification for a threat's impact and likelihood.

    Args:
        threat: Threat dict; reads "component_name", "impact" and
            "likelihood".
        dfd_data: Iterable of DFD flow dicts with "source", "destination"
            and "protocol" keys, used to look up the flow's protocol.

    Returns:
        One string made of the impact sentence followed by the likelihood
        sentence(s), separated by a single space.
    """
    component = threat["component_name"]
    # Protocol of the first flow whose label equals the component name.
    protocol = next((flow["protocol"] for flow in dfd_data
                     if f"{flow['source']} to {flow['destination']}" == component), "Unknown")

    impact_reason = f"Impact rated {threat['impact']} because "
    if "DB_P" in component:
        # This wording takes precedence over the generic per-impact text.
        impact_reason += "of potential exposure of sensitive customer data, leading to regulatory fines or reputational damage."
    elif threat["impact"] == "High":
        impact_reason += "of potential for severe business disruption or data breach."
    elif threat["impact"] == "Medium":
        impact_reason += "of moderate disruption or partial data exposure."
    else:
        impact_reason += "of minimal operational impact."

    likelihood_reason = f"Likelihood rated {threat['likelihood']} because "
    if "Public Zone" in component or "U to" in component:
        likelihood_reason += "component is internet-facing, increasing attack surface."
    elif "Internal Core" in component or "Data Zone" in component:
        likelihood_reason += "component is internal, reducing exposure."
    else:
        likelihood_reason += "of moderate attack surface."
    if "TLS" in protocol or "HTTPS" in protocol:
        likelihood_reason += f" Secure protocols (e.g., {protocol}) reduce exploitability."

    return f"{impact_reason} {likelihood_reason}"

def generate_risk_statement(threat, industry="Generic"):
    """Generate a business-contextualized risk statement.

    Args:
        threat: Threat dict; reads "component_name", "threat_description",
            "impact", "mitigation_maturity" and "mitigation_suggestion".
        industry: "Finance" or "Healthcare" adds a regulatory sentence; any
            other value adds none.

    Returns:
        The statement as sentences joined by single spaces, with no
        trailing whitespace.
    """
    impact_map = {
        "High": "significant financial loss (e.g., >$500K), regulatory fines, or reputational damage",
        "Medium": "moderate financial loss (e.g., $50K-$500K) or operational disruption",
        "Low": "minimal financial or operational impact"
    }
    component = threat["component_name"]
    # An impact outside High/Medium/Low is described like "Low", mirroring
    # the fallback used by the risk matrix, instead of raising KeyError.
    consequence = impact_map.get(threat["impact"], impact_map["Low"])

    sentences = [
        f"Risk of {threat['threat_description'].lower()} on {component} could lead to {consequence}."
    ]
    if industry == "Finance":
        sentences.append("This may violate PCI-DSS or SEC regulations.")
    elif industry == "Healthcare":
        sentences.append("This may violate HIPAA or GDPR regulations.")
    if threat["mitigation_maturity"] == "Mature":
        sentences.append(f"Existing {threat['mitigation_suggestion'].lower()} reduces residual risk.")

    # Joining avoids the double or trailing space that plain concatenation
    # leaves when the regulatory sentence is absent.
    return " ".join(sentences)

def _tls_version_at_least(version, minimum=(1, 3)):
    """Return True when *version* contains a dotted number that is >= *minimum*.

    Accepts strings such as "1.3" or "TLSv1.3" as well as numbers; a value
    without any digits (including None) is treated as not meeting the minimum.
    """
    digit_runs = "".join(ch if ch in "0123456789" else " " for ch in str(version)).split()
    if not digit_runs:
        return False
    return tuple(int(run) for run in digit_runs) >= minimum


def suppress_threats(threats, controls):
    """Suppress or downgrade threats based on implemented controls.

    Only threats whose component name contains "HTTPS" and whose mitigation
    mentions "TLS" are considered, and only when ``https_enabled`` is set.
    With TLS >= 1.3 a threat whose description contains "MitM" has its
    likelihood lowered to "Low" (risk score recomputed, note appended to the
    justification). With ``mtls_enabled`` a threat whose description
    mentions "spoof" is dropped.

    Args:
        threats: List of threat dicts; downgraded threats are modified in place.
        controls: Dict of client controls ("https_enabled", "tls_version",
            "mtls_enabled"); missing keys are treated as not enabled.

    Returns:
        A new list with the threats that were not suppressed.
    """
    suppressed = []
    for threat in threats:
        downgrade = False
        if controls.get("https_enabled") and "HTTPS" in threat["component_name"] and "TLS" in threat["mitigation_suggestion"]:
            # Compare versions numerically: a string comparison raises on a
            # missing or numeric value and mis-orders multi-digit parts.
            if _tls_version_at_least(controls.get("tls_version")) and "MitM" in threat["threat_description"]:
                threat["likelihood"] = "Low"
                threat["risk_score"] = calculate_risk_score(threat["impact"], threat["likelihood"])
                # The threat may not carry a justification yet at this stage.
                prior = threat.get("justification", "")
                threat["justification"] = f"{prior} Downgraded due to robust TLS 1.3 implementation.".strip()
                downgrade = True
            if controls.get("mtls_enabled") and "spoof" in threat["threat_description"].lower():
                logger.info(f"--- Suppressing {threat['component_name']} ({threat['stride_category']}) due to mTLS ---")
                continue
        suppressed.append(threat)
        if downgrade:
            logger.info(f"--- Downgraded likelihood for {threat['component_name']} ({threat['stride_category']}) ---")
    return suppressed

def deduplicate_threats(threats, similarity_threshold=0.8):
    """Deduplicate threats using clustering and description/mitigation similarity.

    Each threat's description and mitigation are embedded together and
    clustered with DBSCAN on cosine distance. Threats that share a cluster,
    component name and STRIDE category are merged into the first of them,
    which receives the union of their references and the longest mitigation.

    Args:
        threats: List of threat dicts; the surviving threat of a merged
            group is modified in place.
        similarity_threshold: Similarity above which two texts may fall in
            the same cluster (DBSCAN ``eps`` is ``1 - similarity_threshold``).

    Returns:
        The deduplicated list of threats; an empty list for empty input.
    """
    logger.info("--- Starting threat deduplication ---")
    if not threats:
        # Nothing to embed or cluster; skip loading the model altogether.
        logger.info(f"--- Reduced {len(threats)} threats to 0 ---")
        return []

    model = SentenceTransformer('all-mpnet-base-v2')

    # Combine description and mitigation for similarity
    combined_texts = [f"{threat['threat_description']} {threat['mitigation_suggestion']}" for threat in threats]
    embeddings = model.encode(combined_texts, convert_to_tensor=True).cpu().numpy()

    # Cluster threats using DBSCAN
    clustering = DBSCAN(eps=1 - similarity_threshold, min_samples=1, metric="cosine").fit(embeddings)
    labels = clustering.labels_

    # Group threats by cluster, component, and STRIDE
    groups = {}
    for idx, label in enumerate(labels):
        key = (label, threats[idx]["component_name"], threats[idx]["stride_category"])
        groups.setdefault(key, []).append(idx)

    # Merge clusters
    deduplicated_threats = []
    for indices in groups.values():
        primary_threat = threats[indices[0]]
        if len(indices) > 1:
            # Union of references in first-seen order so the output is
            # deterministic; threats without references contribute nothing.
            merged_references = {}
            for idx in indices:
                for reference in threats[idx].get("references") or []:
                    merged_references.setdefault(reference, None)
            primary_threat["references"] = list(merged_references)
            # Select most detailed mitigation
            mitigations = [threats[i]["mitigation_suggestion"] for i in indices]
            primary_threat["mitigation_suggestion"] = max(mitigations, key=len)
            logger.info(f"--- Merged {len(indices)} threats for {primary_threat['component_name']} ({primary_threat['stride_category']}) ---")
        deduplicated_threats.append(primary_threat)

    logger.info(f"--- Reduced {len(threats)} threats to {len(deduplicated_threats)} ---")
    return deduplicated_threats

# --- Main Refinement Logic ---
def refine_threats():
    """Refine threats by deduplicating, standardizing, and enhancing with risks.

    Reads the threats JSON, standardizes component names against the DFD,
    applies control-based suppression, deduplicates, recomputes ratings and
    metadata, sorts by risk, validates the result and writes the refined
    JSON, a summary JSON and a CSV export.

    Raises:
        ValueError: If the input file contains no threats.
        ValidationError: If the final output fails schema validation.
        Exception: Any error from reading or parsing the input is logged
            and re-raised.
    """
    logger.info(f"--- Loading threats from '{THREATS_INPUT_PATH}' ---")
    try:
        with open(THREATS_INPUT_PATH, 'r', encoding='utf-8') as f:
            threat_data = json.load(f)
        threats = threat_data.get("threats", [])
        if not threats:
            logger.error("--- No threats found in input file ---")
            raise ValueError("Empty threats list")
    except Exception as e:
        logger.error(f"--- Failed to load threats: {e} ---")
        raise

    # Remember the size of the input before suppression rebinds ``threats``,
    # so the metadata reports the true original count.
    original_threat_count = len(threats)

    # Load DFD and controls
    dfd_data = load_dfd_components()
    controls = load_controls()
    industry = os.getenv("CLIENT_INDUSTRY", "Generic")

    # Step 1: Standardize component names
    for threat in threats:
        threat["component_name"] = standardize_component_name(threat, dfd_data)

    # Step 2: Suppress threats based on controls
    threats = suppress_threats(threats, controls)

    # Step 3: Deduplicate threats
    deduplicated_threats = deduplicate_threats(threats)

    # Step 4: Refine ratings and add metadata
    refined_threats = []
    for threat in deduplicated_threats:
        # Adjust likelihood for internal components
        if any(zone in threat["component_name"] for zone in ["Internal Core", "Data Zone", "Management Zone"]):
            if threat["likelihood"] == "Medium":
                threat["likelihood"] = "Low"
                logger.debug(f"--- Detected internal zone for {threat['component_name']} ---")
                logger.info(f"--- Adjusted likelihood to Low for {threat['component_name']} ---")
        # Recalculate risk score
        threat["risk_score"] = calculate_risk_score(threat["impact"], threat["likelihood"])
        # Add exploitability and mitigation maturity
        threat["exploitability"] = assess_exploitability(threat, dfd_data)
        threat["mitigation_maturity"] = assess_mitigation_maturity(threat["mitigation_suggestion"])
        # Add justification and risk statement
        # NOTE(review): this assignment replaces any text suppress_threats
        # appended to threat["justification"] in Step 2 - confirm intended.
        threat["justification"] = generate_justification(threat, dfd_data)
        threat["risk_statement"] = generate_risk_statement(threat, industry)
        refined_threats.append(threat)

    # Step 5: Sort by risk score
    risk_order = {"Critical": 4, "High": 3, "Medium": 2, "Low": 1}
    refined_threats.sort(key=lambda t: risk_order.get(t["risk_score"], 0), reverse=True)

    # Step 6: Generate final output
    final_output = {
        "threats": refined_threats,
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "source_dfd": os.path.basename(DFD_INPUT_PATH),
            "source_threats": os.path.basename(THREATS_INPUT_PATH),
            "refined_threat_count": len(refined_threats),
            "original_threat_count": original_threat_count,
            "industry_context": industry
        }
    }

    # Validate output (the model instance itself is not needed afterwards)
    try:
        RefinedThreatsOutput(**final_output)
        logger.info("--- Final refined JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- Final JSON validation failed: {ve} ---")
        raise

    # Save output
    with open(REFINED_THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, indent=2)
    logger.info(f"--- Refined {len(refined_threats)} threats saved to '{REFINED_THREATS_OUTPUT_PATH}' ---")

    # Generate summary report
    summary = {
        "total_threats": len(refined_threats),
        "critical_count": sum(1 for t in refined_threats if t["risk_score"] == "Critical"),
        "high_count": sum(1 for t in refined_threats if t["risk_score"] == "High"),
        "medium_count": sum(1 for t in refined_threats if t["risk_score"] == "Medium"),
        "low_count": sum(1 for t in refined_threats if t["risk_score"] == "Low"),
        "prioritization_recommendation": (
            "Address Critical and High-risk threats within 30 days. Implement mitigations such as mTLS, "
            "rate limiting, and robust logging to reduce risk. Review Medium and Low risks for long-term mitigation."
        )
    }
    summary_path = os.path.join(INPUT_DIR, "threat_summary.json")
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    logger.info(f"--- Summary report saved to '{summary_path}' ---")

    # Export to CSV
    import pandas as pd
    csv_path = os.path.join(INPUT_DIR, "threats.csv")
    pd.DataFrame(final_output["threats"]).to_csv(csv_path, index=False)
    logger.info(f"--- CSV report saved to '{csv_path}' ---")

# --- Execute ---
if __name__ == "__main__":
    try:
        refine_threats()
    except Exception as exc:
        # Top-level boundary: record the failure together with its traceback
        # rather than letting it propagate out of the script.
        logger.exception(f"--- Refinement process failed: {exc} ---")

2025-07-28 20:14:07,270 - INFO - --- Loading threats from './output/identified_threats.json' ---
2025-07-28 20:14:07,271 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-28 20:14:07,271 - INFO - --- Loading controls from './output/controls.json' ---
2025-07-28 20:14:07,271 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via partial match ---
2025-07-28 20:14:07,272 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via partial match ---
2025-07-28 20:14:07,272 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via partial match ---
2025-07-28 20:14:07,272 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via partial match ---
2025-07-28 20:14:07,272 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via partial match ---
2025-07-28 20:14:07,272 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via partial match ---
2025-07-28 20:14:07,272 - INFO - --- Matched 'CDN to LB Data Flow' to 'CDN to LB' via parti

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2025-07-28 20:14:10,207 - INFO - --- Merged 2 threats for U to CDN (S) ---
2025-07-28 20:14:10,207 - INFO - --- Merged 2 threats for U to CDN (D) ---
2025-07-28 20:14:10,207 - INFO - --- Merged 2 threats for CDN to LB (I) ---
2025-07-28 20:14:10,207 - INFO - --- Merged 2 threats for LB to WS (R) ---
2025-07-28 20:14:10,208 - INFO - --- Merged 2 threats for LB to WS (D) ---
2025-07-28 20:14:10,208 - INFO - --- Merged 2 threats for LB to WS (E) ---
2025-07-28 20:14:10,208 - INFO - --- Merged 2 threats for WS to DB_P (I) ---
2025-07-28 20:14:10,208 - INFO - --- Merged 2 threats for WS to MQ (S) ---
2025-07-28 20:14:10,208 - INFO - --- Merged 2 threats for WS to MQ (I) ---
2025-07-28 20:14:10,209 - INFO - --- Merged 2 threats for WS to MQ (D) ---
2025-07-28 20:14:10,209 - INFO - --- Merged 2 threats for WS to MQ (E) ---
2025-07-28 20:14:10,209 - INFO - --- Merged 2 threats for WRK to MQ (S) ---
2025-07-28 20:14:10,209 - INFO - --- Merged 2 threats for WRK to MQ (I) ---
2025-07-28 20:14:10,