In [82]:
import os
import json
from langchain_core.prompts import ChatPromptTemplate
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import instructor
from ollama import Client
import PyPDF2
import glob
from openai import OpenAI
import docx  # Added for DOCX support

# Load environment variables
load_dotenv()

# --- Configuration ---
# Every setting can be overridden through the environment (or a .env file).
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "scaleway").lower()  # Defaults to 'scaleway'; any other value selects the Ollama branch of initialize_llm_client()
LLM_MODEL = os.getenv("LLM_MODEL", "llama-3.3-70b-instruct")
# NOTE(review): the default URL embeds what looks like a specific project ID — confirm it is meant to be committed.
SCW_API_URL = os.getenv("SCW_API_URL", "https://api.scaleway.ai/4a8fd76b-8606-46e6-afe6-617ce8eeb948/v1")
SCW_SECRET_KEY = os.getenv("SCW_SECRET_KEY")  # No default; initialize_llm_client() raises if it is missing for Scaleway
INPUT_DIR = os.getenv("INPUT_DIR", "./input_documents")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./output")
DFD_OUTPUT_PATH = os.getenv("DFD_OUTPUT_PATH", os.path.join(OUTPUT_DIR, "dfd_components.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Initialize LLM Client ---
def initialize_llm_client():
    """Build the LLM client(s) for the provider selected by ``LLM_PROVIDER``.

    Returns:
        When ``LLM_PROVIDER == "scaleway"``:
            ``(instructor_client, "scaleway")``
        For any other provider value:
            ``(raw_ollama_client, instructor_client, "ollama")``

        The tuple length differs between providers, so callers must branch
        on ``LLM_PROVIDER`` before unpacking.

    Raises:
        ValueError: ``SCW_SECRET_KEY`` is not set while Scaleway is selected.
        Exception: any client construction error, re-raised after logging.
    """
    if LLM_PROVIDER == "scaleway":
        if not SCW_SECRET_KEY:
            raise ValueError("SCW_SECRET_KEY environment variable is required for Scaleway API.")
        try:
            client = instructor.from_openai(OpenAI(base_url=SCW_API_URL, api_key=SCW_SECRET_KEY))
            logger.info("--- Scaleway OpenAI client initialized successfully ---")
            return client, "scaleway"
        except Exception as e:
            logger.error(f"--- Failed to initialize Scaleway client: {e} ---")
            raise
    else:  # Default to Ollama
        try:
            raw_client = Client()  # Raw Ollama client for debugging
            # instructor.patch() wraps `client.chat.completions.create`, which
            # the native Ollama client does not have (its `chat` is a plain
            # method), so patching it fails. Use Ollama's OpenAI-compatible
            # endpoint instead; callers keep calling
            # `instructor_client.chat.completions.create(...)`.
            ollama_api_url = os.getenv("OLLAMA_API_URL", "http://localhost:11434/v1")
            instructor_client = instructor.from_openai(
                # Ollama ignores the key, but the OpenAI client requires one.
                OpenAI(base_url=ollama_api_url, api_key="ollama"),
                mode=instructor.Mode.JSON,
            )
            logger.info("--- Ollama client initialized successfully ---")
            return raw_client, instructor_client, "ollama"
        except Exception as e:
            logger.error(f"--- Failed to initialize Ollama client: {e} ---")
            raise

# --- DFD Schema for Validation ---
# One directed data flow between two DFD components.
# All six fields are declared as `str` with no default value; the Field
# descriptions carry the examples shown to the model.
class DataFlow(BaseModel):
    source: str = Field(description="Source component of the data flow (e.g., 'U' for User).")
    destination: str = Field(description="Destination component of the data flow (e.g., 'CDN').")
    data_description: str = Field(description="Description of data being transferred (e.g., 'User session tokens').")
    data_classification: str = Field(description="Classification like 'Confidential', 'PII', or 'Public'.")
    protocol: str = Field(description="Protocol used (e.g., 'HTTPS', 'JDBC/ODBC over TLS').")
    authentication_mechanism: str = Field(description="Authentication method (e.g., 'JWT in Header').")

# Structured result requested from the LLM: this class is passed as
# `response_model` to the instructor client in the invocation section below.
# Holds project identification, the component lists, and the data flows.
class DFDComponents(BaseModel):
    project_name: str = Field(description="Name of the project (e.g., 'Web Application Security Model').")
    project_version: str = Field(description="Version of the project (e.g., '1.1').")
    industry_context: str = Field(description="Industry context (e.g., 'Finance').")
    external_entities: list[str] = Field(description="List of external entities (e.g., ['U', 'Attacker']).")
    assets: list[str] = Field(description="List of assets like data stores (e.g., ['DB_P', 'DB_B']).")
    processes: list[str] = Field(description="List of processes (e.g., ['CDN', 'LB', 'WS']).")
    trust_boundaries: list[str] = Field(description="List of trust boundaries (e.g., ['Public Zone to Edge Zone']).")
    data_flows: list[DataFlow] = Field(description="List of data flows between components.")

# Envelope written to DFD_OUTPUT_PATH: the extracted DFD plus run metadata.
# Used only to validate the final dict before it is saved.
class DFDOutput(BaseModel):
    dfd: DFDComponents
    metadata: dict  # free-form; the script fills timestamp, source_documents, assumptions, llm_provider

# --- Sample Input for Testing (if no documents are found) ---
# Fallback text returned by load_documents() as the only document when
# nothing could be loaded from the input directory.
SAMPLE_DOCUMENT_CONTENT = """
System: Web Application Security Model, Version 1.1, Finance Industry
External Entities: User (U), External Attacker
Assets: Profile Database (DB_P), Billing Database (DB_B)
Processes: Content Delivery Network (CDN), Load Balancer (LB), Web Server (WS), Message Queue (MQ), Worker (WRK), Admin Service (ADM), Admin Portal (ADM_P)
Trust Boundaries: Public Zone to Edge Zone, Edge Zone to Application DMZ, Application DMZ to Internal Core, Internal Core to Data Zone, Management Zone to Application DMZ
Data Flows:
- From User to CDN: User session tokens and requests for static assets, Confidential, HTTPS, JWT in Header
- From CDN to LB: Cached content and user requests, Confidential, HTTPS, mTLS
- From WS to DB_P: User profile data including names and email addresses, PII, JDBC/ODBC over TLS, Database Credentials from Secrets Manager
"""

# --- Load and Parse Documents ---
def load_documents(input_dir):
    """Read every supported document directly inside *input_dir* as text.

    Supported extensions (matched case-insensitively): ``.txt`` and ``.md``
    (read as UTF-8), ``.pdf`` (PyPDF2) and ``.docx`` (python-docx). The
    directory is not searched recursively.

    A file that fails to load, or that yields no text at all (for example a
    scanned PDF), is skipped with a warning instead of contributing an empty
    document.

    Args:
        input_dir: Directory to scan.

    Returns:
        A list with one string per loaded document, in sorted path order.
        If nothing could be loaded, ``[SAMPLE_DOCUMENT_CONTENT]``.
    """
    logger.info(f"--- Loading documents from '{input_dir}' ---")
    # Character classes make the match case-insensitive on every platform.
    file_patterns = (
        "*.[tT][xX][tT]",      # TXT files
        "*.[pP][dD][fF]",      # PDF files
        "*.[dD][oO][cC][xX]",  # DOCX files
        "*.[mM][dD]",          # Markdown files
    )
    # glob order is filesystem-dependent; sort so the prompt is reproducible.
    all_files = sorted({
        path
        for pattern in file_patterns
        for path in glob.glob(os.path.join(input_dir, pattern))
    })

    documents = []
    for file_path in all_files:
        try:
            ext = os.path.splitext(file_path)[1].lower()
            if ext in (".txt", ".md"):  # Treat MD like TXT
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
                kind = "text-based"
            elif ext == ".pdf":
                with open(file_path, 'rb') as f:
                    # Extract each page once (extraction is the expensive
                    # step) and keep a newline between pages so words do not
                    # fuse across a page break.
                    page_texts = (page.extract_text() for page in PyPDF2.PdfReader(f).pages)
                    text = "\n".join(t for t in page_texts if t)
                kind = "PDF"
            elif ext == ".docx":
                doc = docx.Document(file_path)
                text = "\n".join(para.text for para in doc.paragraphs if para.text)
                kind = "DOCX"
            else:
                continue
        except Exception as e:
            logger.warning(f"Failed to load {file_path}: {e}")
            continue

        # An empty document would suppress the sample fallback below while
        # giving the LLM nothing to analyse.
        if not text.strip():
            logger.warning(f"No text extracted from {file_path}; skipping.")
            continue

        documents.append(text)
        logger.info(f"Loaded {kind} file: {file_path}")

    if not documents:
        logger.warning("--- No valid documents found. Using sample document content ---")
        documents = [SAMPLE_DOCUMENT_CONTENT]
    return documents

# --- Prompt Engineering for Document Extraction ---
# Single-placeholder template: `{documents}` receives all loaded documents
# joined into one string by the invocation section.
# NOTE(review): step 3 asks the model to record inferred values under an
# `assumptions` key in the metadata, but DFDComponents has no such field and
# the invocation section hard-codes "assumptions": [] — those notes are never
# captured. Confirm whether the schema or the prompt should change.
# NOTE(review): step 4 appears after the documents block, out of sequence with
# steps 1-3 — confirm this placement is intentional.
extract_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling. Your task is to extract structured information from one or more input documents describing a system and transform it into a comprehensive and accurate JSON object representing a Data Flow Diagram (DFD).

Your analysis must be meticulous. Follow these reasoning steps precisely:

1.  **Identify Core Components**: First, perform a full scan of the document(s) to identify and list all high-level components. This includes:
    * `project_name`, `project_version`, and `industry_context`.
    * `external_entities`: Any user, actor, or system outside the primary application boundary.
    * `processes`: The distinct computational components or services that handle data.
    * `assets`: The data stores, such as databases, object storage buckets, or message queues.
    * `trust_boundaries`: The defined boundaries separating zones of different trust levels.

2.  **Systematically Extract ALL Data Flows**: This is the most critical step. You must identify every single flow of data mentioned or implied in the documents. Scrutinize sections like "Use Cases," "Data Flow Diagrams," "Architecture," and "Technology Stack" to find them. Create a data flow entry for each of the following interaction types:
    * **External-to-Process**: Flows from an `external_entity` to an internal `process` (e.g., user submitting credentials, uploading a file).
    * **Process-to-External**: Flows from an internal `process` to an `external_entity` (e.g., returning results, sending a session token).
    * **Process-to-Asset**: Flows where a `process` reads from or writes to a data store `asset` (e.g., "Authentication Service reads from UsersDB," "Analysis process writes to ResultsDB"). These are essential and must not be omitted.
    * **Process-to-Process**: Flows between internal `processes` (e.g., "API Gateway routes request to Authentication Service").

3.  **Detail and Classify Each Flow**: For every data flow you identify, you must accurately populate all its attributes: `source`, `destination`, `data_description`, `data_classification`, `protocol`, and `authentication_mechanism`.
    * **Data Classification Rules**: Apply strict classification.
        * **Confidential**: Use for any data that, if exposed, could harm the organization or its users. This includes, but is not limited to: credentials, session tokens (JWTs), API keys, SAML assertions, Personally Identifiable Information (PII), health information (PHI), financial data, and proprietary business logic.
        * **Public**: Use ONLY for data that is explicitly intended for public consumption and carries no security risk if intercepted (e.g., a list of available public APIs). **Authentication-related data is never public.**
    * If information for a field (like `protocol` or `authentication_mechanism`) is not explicitly stated, infer it from the context (e.g., a web service likely uses HTTPS, a database connection likely uses JDBC/ODBC over TLS) and make a note of this in the `assumptions` key in the metadata.


Input Documents:
---
{documents}
---

4.  **Final Review**: Before generating the final output, review the generated list of data flows against the use cases in the source document. Ensure that every major action described in the use cases is represented by one or more data flows in your output.


Output ONLY the JSON, with no additional commentary or formatting.
"""

# Compiled template; format_messages(documents=...) yields the message list.
extract_prompt = ChatPromptTemplate.from_template(extract_prompt_template)

# --- Invocation and Output ---
# End-to-end run: build the client, load the documents, ask the LLM for a
# DFDComponents object, wrap it with metadata, validate, and write the JSON.
logger.info("\n--- Starting Pre-Filter for Document Extraction ---")
try:
    # initialize_llm_client() returns a 2-tuple for Scaleway and a 3-tuple
    # otherwise, so the unpacking has to branch on the provider.
    if LLM_PROVIDER == "scaleway":
        client, client_type = initialize_llm_client()
    else:
        raw_client, instructor_client, client_type = initialize_llm_client()

    # Load documents
    documents = load_documents(INPUT_DIR)
    documents_combined = "\n--- Document Separator ---\n".join(documents)

    # Generate messages from the prompt template
    messages = extract_prompt.format_messages(documents=documents_combined)
    prompt_text = messages[0].content

    # Log the prompt for debugging
    logger.info(f"--- Prompt sent to LLM ---\n{prompt_text}")

    # A fresh message list is built for every request on purpose: the
    # structured-output wrapper may modify the list it is given, and that must
    # not leak into the raw debug request.
    if client_type == "scaleway":
        # Use instructor client for Scaleway
        dfd_obj = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt_text}],
            response_model=DFDComponents,
            max_retries=5
        )
        # Debug only: a second, unstructured completion. It is a separate
        # request, so its text and token counts describe that request, not
        # the structured one above. A failure here must not throw away the
        # structured result, hence the local try/except.
        try:
            raw_client = OpenAI(base_url=SCW_API_URL, api_key=SCW_SECRET_KEY)
            raw_response = raw_client.chat.completions.create(
                model=LLM_MODEL,
                messages=[{"role": "user", "content": prompt_text}],
                response_format={"type": "json_object"}
            )
            logger.info(f"--- Raw Scaleway Response ---\n{raw_response.choices[0].message.content}")
            # `usage` exists as an attribute even when the server sent no
            # usage data, so test the value rather than using hasattr().
            usage = getattr(raw_response, 'usage', None)
            if usage is not None:
                prompt_tokens = usage.prompt_tokens or 'N/A'
                completion_tokens = usage.completion_tokens or 'N/A'
                total_tokens = usage.total_tokens or 'N/A'
                logger.info("--- Token Usage for Scaleway ---")
                logger.info(f"Input Tokens: {prompt_tokens}")
                logger.info(f"Output Tokens: {completion_tokens}")
                logger.info(f"Total Tokens: {total_tokens}")
        except Exception as debug_err:
            logger.warning(f"--- Raw Scaleway debug call failed (structured result kept): {debug_err} ---")
    else:
        # Use instructor client for Ollama
        dfd_obj = instructor_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt_text}],
            response_model=DFDComponents,
            max_retries=5
        )
        # Debug only, best-effort (same reasoning as the Scaleway branch).
        try:
            raw_response = raw_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": prompt_text}])
            logger.info(f"--- Raw Ollama Response ---\n{raw_response['message']['content']}")
            # Log Token Count and Performance (durations are reported in ns)
            prompt_tokens = raw_response.get('prompt_eval_count', 'N/A')
            prompt_duration_ns = raw_response.get('prompt_eval_duration', 0)
            response_tokens = raw_response.get('eval_count', 'N/A')
            response_duration_ns = raw_response.get('eval_duration', 0)
            prompt_duration_s = f"{prompt_duration_ns / 1_000_000_000:.2f}s" if prompt_duration_ns else "N/A"
            response_duration_s = f"{response_duration_ns / 1_000_000_000:.2f}s" if response_duration_ns else "N/A"
            logger.info("--- Token Usage & Performance ---")
            logger.info(f"Input Tokens: {prompt_tokens} (processed in {prompt_duration_s})")
            logger.info(f"Output Tokens: {response_tokens} (generated in {response_duration_s})")
        except Exception as debug_err:
            logger.warning(f"--- Raw Ollama debug call failed (structured result kept): {debug_err} ---")

    dfd_dict = dfd_obj.model_dump()

    # Candidate input files, in the same pattern order as before.
    # NOTE(review): this lists every matching file in INPUT_DIR, including
    # ones that failed to load, and is still populated when the sample
    # fallback was used — confirm that is the intended meaning.
    source_patterns = ("*.[tT][xX][tT]", "*.[pP][dD][fF]", "*.[dD][oO][cC][xX]", "*.[mM][dD]")
    source_documents = [
        path
        for pattern in source_patterns
        for path in glob.glob(os.path.join(INPUT_DIR, pattern))
    ]

    # Add metadata
    output_dict = {
        "dfd": dfd_dict,
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "source_documents": source_documents,
            "assumptions": [],
            "llm_provider": LLM_PROVIDER
        }
    }

    # Validate the output against schema
    try:
        validated = DFDOutput(**output_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise

    # Save the DFD components to a file
    with open(DFD_OUTPUT_PATH, 'w') as f:
        json.dump(output_dict, f, indent=2)

    logger.info("\n--- LLM Output (DFD Components) ---")
    print(json.dumps(output_dict, indent=2))
    logger.info(f"\n--- DFD components successfully saved to '{DFD_OUTPUT_PATH}' ---")

except Exception as e:
    # Top-level boundary: report and keep the notebook alive. exc_info adds
    # the traceback so the failing step can be identified.
    logger.error("\n--- An error occurred during document extraction ---")
    logger.error(f"Error: {e}", exc_info=True)
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input documents.")

2025-07-29 13:01:17,771 - INFO - 
--- Starting Pre-Filter for Document Extraction ---
2025-07-29 13:01:17,789 - INFO - --- Scaleway OpenAI client initialized successfully ---
2025-07-29 13:01:17,790 - INFO - --- Loading documents from './input_documents' ---
2025-07-29 13:01:17,801 - INFO - Loaded DOCX file: ./input_documents/Design Document.docx
2025-07-29 13:01:17,802 - INFO - --- Prompt sent to LLM ---

You are a senior cybersecurity analyst specializing in threat modeling. Your task is to extract structured information from one or more input documents describing a system and transform it into a comprehensive and accurate JSON object representing a Data Flow Diagram (DFD).

Your analysis must be meticulous. Follow these reasoning steps precisely:

1.  **Identify Core Components**: First, perform a full scan of the document(s) to identify and list all high-level components. This includes:
    * `project_name`, `project_version`, and `industry_context`.
    * `external_entities`: Any 

{
  "dfd": {
    "project_name": "HealthData Insights Platform",
    "project_version": "1.0",
    "industry_context": "Healthcare",
    "external_entities": [
      "Researcher",
      "Organizational Admin",
      "External Identity Provider"
    ],
    "assets": [
      "UsersDB",
      "RawDataBucket",
      "ProcessedDataBucket",
      "ResultsDB"
    ],
    "processes": [
      "User Authentication",
      "Data Ingestion",
      "Data Analysis",
      "Results Presentation"
    ],
    "trust_boundaries": [
      "External User Zone -> Application Frontend",
      "Internet -> VPC Public Subnet",
      "VPC Public Subnet -> VPC Private Subnet",
      "Application -> Data Stores"
    ],
    "data_flows": [
      {
        "source": "Researcher",
        "destination": "User Authentication",
        "data_description": "Login credentials",
        "data_classification": "Confidential",
        "protocol": "HTTPS",
        "authentication_mechanism": "MFA"
      },
      {
        "

In [17]:
# --- Dependencies ---
# Ensure you have these packages installed (logging is part of the standard library and must not be pip-installed):
# pip install langchain langchain-community langchain-ollama python-dotenv pydantic instructor ollama

import os
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import instructor
from ollama import Client  # Added for raw response debugging

# Load environment variables
load_dotenv()

# --- Configuration ---
# Use environment variables for paths and settings
LLM_MODEL = os.getenv("LLM_MODEL", "llama3:8b")  # Changed default to a more reliable model for testing
# Despite its name, INPUT_DIR is also the default location of the output
# file: both default paths below are built from it.
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists early
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize LLM with Instructor for schema enforcement
# The model name is bound here, which is why the later create() call passes no `model`.
llm = instructor.from_provider(f"ollama/{LLM_MODEL}", mode=instructor.Mode.JSON_SCHEMA)

# Added: Raw Ollama client for debugging
ollama_client = Client()

# --- Threat Schema for Validation ---
# One identified threat. Category, impact and likelihood are plain strings:
# the allowed values are stated only in the field descriptions, not enforced
# by the type.
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One STRIDE category: Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege.")
    threat_description: str = Field(description="Clear, specific description of the threat.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low/Medium/High based on potential damage.")
    likelihood: str = Field(description="Low/Medium/High based on exploitability.")
    references: list[str] = Field(description="Array of standard references (e.g., ['OWASP A01:2021', 'NIST SI-2']).")

# Shape requested from the LLM (passed as `response_model` in the invocation section).
class Threats(BaseModel):
    threats: list[Threat]

# Shape of the saved file: the threat list plus run metadata. Used to
# re-validate the dict after the metadata has been attached.
class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict  # free-form; the script fills timestamp and source_dfd

# --- Sample DFD for Testing (if input is empty) ---
# Fallback used when the DFD file is missing, unparsable, or empty.
# NOTE(review): this sample uses different keys ("data_stores", "from"/"to"/
# "data") and no {"dfd": ..., "metadata": ...} envelope, unlike the
# dfd_components.json layout visible in the extraction step's output
# ("assets", "source"/"destination"/"data_description"). It is only serialized
# into the prompt, so nothing breaks, but confirm whether the sample should
# mirror the real file format.
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "from": "User",
            "to": "Web Application",
            "data": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "from": "Web Application",
            "to": "User Database",
            "data": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
# Reads the DFD JSON produced by the extraction step into `dfd_data`.
# A missing, unparsable, or empty file falls back to SAMPLE_DFD so the rest
# of the cell can still run; any other error stops execution.
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    # Not fatal: execution continues with the sample, so it is not labelled as such.
    logger.error(f"--- Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error("--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    # Same effect as exit(1), without relying on the `exit` helper that the
    # `site` module / interactive shell injects.
    raise SystemExit(1) from e
else:
    if not dfd_data:
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    else:
        # Reported only when the file's own content is actually used.
        logger.info("--- DFD components loaded successfully ---")

# --- Prompt Engineering for Threat Generation ---
# Single-placeholder template: `{dfd_json}` receives the DFD serialized as
# indented JSON. The per-threat keys listed in the prompt match the fields of
# the Threat model.
threat_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where applicable.
2. Describe threats considering trust boundaries, protocols, and potential attack vectors (e.g., injection, misconfiguration).
3. Suggest mitigations with references to standards (e.g., "NIST AC-6 for least privilege").
4. Assess impact (Low/Medium/High based on potential damage) and likelihood (Low/Medium/High based on exploitability).

For each threat, include:
- 'component_name': Affected asset, process, data flow, or entity.
- 'stride_category': One STRIDE category.
- 'threat_description': Clear, specific description (e.g., "Attacker intercepts unencrypted data in transit leading to disclosure").
- 'mitigation_suggestion': Practical, actionable mitigation (e.g., "Implement TLS 1.3 with certificate pinning").
- 'impact': Low/Medium/High.
- 'likelihood': Low/Medium/High.
- 'references': Array of strings (e.g., ["OWASP A01:2021", "NIST SI-2"]).

DFD Components:
---
{dfd_json}
---

Generate a JSON object with a key 'threats' (array of threat objects). Output ONLY the JSON, with no additional commentary or formatting.
"""

# Compiled template; format_messages(dfd_json=...) yields the message list.
threat_prompt = ChatPromptTemplate.from_template(threat_prompt_template)

# --- Invocation and Output ---
# Sends the DFD to the LLM, receives a validated Threats object, attaches
# metadata, re-validates, and writes the result to THREATS_OUTPUT_PATH.
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Convert the loaded DFD dictionary back to a JSON string for the prompt
    dfd_json_string = json.dumps(dfd_data, indent=2)

    # Generate messages from the prompt template
    messages = threat_prompt.format_messages(dfd_json=dfd_json_string)
    prompt_text = messages[0].content

    # Log the prompt for debugging
    logger.info(f"--- Prompt sent to LLM ---\n{prompt_text}")

    # Debug only: an unstructured completion logged before the structured
    # call. It is a separate request whose output is never used afterwards,
    # so a failure here is reported without aborting the run.
    try:
        raw_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": prompt_text}])
        logger.info(f"--- Raw LLM Response ---\n{raw_response['message']['content']}")
    except Exception as debug_err:
        logger.warning(f"--- Raw LLM debug call failed (continuing with structured call): {debug_err} ---")

    # Invoke the LLM with Instructor for structured output.
    # A fresh message list is passed so nothing is shared with the debug call.
    threats_obj = llm.chat.completions.create(
        messages=[{"role": "user", "content": prompt_text}],
        response_model=Threats,
        max_retries=5  # Increased for better handling
    )

    threats_dict = threats_obj.model_dump()

    # Add metadata
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }

    # Validate the output against schema (Instructor already enforces, but double-check)
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise

    # Save the threats to a new file
    with open(THREATS_OUTPUT_PATH, 'w') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    # Top-level boundary: report and keep the notebook alive. exc_info adds
    # the traceback so the failing step can be identified.
    logger.error("\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}", exc_info=True)
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input data.")

2025-07-27 19:49:51,081 - INFO - Initializing ollama provider with model llama3:8b
2025-07-27 19:49:51,103 - INFO - Client initialized
2025-07-27 19:49:51,118 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 19:49:51,119 - INFO - --- DFD components loaded successfully ---
2025-07-27 19:49:51,119 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-27 19:49:51,120 - INFO - --- Prompt sent to LLM ---

You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where 

{
  "threats": [
    {
      "component_name": "U",
      "stride_category": "Information Disclosure",
      "threat_description": "Unencrypted data transmitted from U to CDN potentially disclosed",
      "mitigation_suggestion": "Implement end-to-end encryption; Use HTTPS protocol",
      "impact": "Medium",
      "likelihood": "High",
      "references": [
        "OWASP A01:2021",
        "NIST SI-2"
      ]
    },
    {
      "component_name": "CDN",
      "stride_category": "Tampering",
      "threat_description": "Malicious actor intercepts and modifies data in transit from CDN to LB",
      "mitigation_suggestion": "Implement integrity checking; Use digital signatures",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A03:2021",
        "MITRE CA-8"
      ]
    },
    {
      "component_name": "LB",
      "stride_category": "Elevation of Privilege",
      "threat_description": "Unprivileged actor gains elevated privileges on LB, potentia

In [None]:
# --- Dependencies ---
# Ensure you have these packages installed (logging is part of the standard library and must not be pip-installed):
# pip install openai pydantic python-dotenv langchain langchain_community langchain_huggingface faiss-cpu pypdf sentence-transformers requests pymupdf pytesseract pillow

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from openai import OpenAI
import requests  # Added for auto-download
import fitz
import io
import pytesseract

# RAG specific imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import (
    DirectoryLoader, PyPDFLoader, TextLoader, Docx2txtLoader, CSVLoader, BSHTMLLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document


# Load environment variables
load_dotenv()

# --- Configuration ---
LLM_MODEL = os.getenv("LLM_MODEL", "llama-3.1-70b-instruct") # Updated model suggestion
# Despite its name, INPUT_DIR is also the default location of the threats
# output file: both default paths below are built from it.
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# RAG Configuration
# Unlike the settings above, these two are fixed relative paths and cannot be
# overridden through the environment.
RAG_DOCS_DIR = "rag_docs"
FAISS_INDEX_PATH = "faiss_index"

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure directories exist
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(RAG_DOCS_DIR, exist_ok=True)

def extract_text_from_scanned_pdf(pdf_path):
    """Extract text from a PDF, running OCR on every embedded image.

    For each page, the page's native text (if any) is combined with the OCR
    output of each image embedded in the page, in that order.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A list of ``Document`` objects, one per page that produced any text.
        Each carries ``source``, ``type`` ("pdf") and a 1-based ``page``
        number in its metadata. Pages without text are omitted.

    Raises:
        Errors from opening the PDF, decoding an image, or running OCR are
        not caught here and propagate to the caller.
    """
    # `Image` is used below but is not part of this file's visible import
    # block; importing it here keeps the function from failing with NameError.
    from PIL import Image

    logger.info(f"Extracting text from scanned PDF: {pdf_path}")
    documents = []

    # The context manager closes the PDF even if OCR raises part-way through.
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc, start=1):
            parts = []

            # Extract native text (if any)
            page_text = page.get_text()
            if page_text.strip():
                parts.append(page_text)

            # Extract images and apply OCR
            for img in page.get_images(full=True):
                xref = img[0]  # first element is the image's cross-reference number
                image_bytes = doc.extract_image(xref)["image"]

                # Convert to PIL Image for OCR
                image = Image.open(io.BytesIO(image_bytes))
                ocr_text = pytesseract.image_to_string(image, lang='eng')
                if ocr_text.strip():
                    parts.append(ocr_text)

            # Every fragment is newline-terminated, as before.
            text = "".join(part + "\n" for part in parts)
            if text.strip():
                # Create Document object for each page
                documents.append(Document(
                    page_content=text,
                    metadata={"source": pdf_path, "type": "pdf", "page": page_num}
                ))

    if not documents:
        logger.warning(f"No text extracted from {pdf_path}. Check OCR setup or PDF content.")
    return documents

def extract_text_from_image(image_path):
    """Extract text from a standalone image file using OCR.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A one-element list holding a ``Document`` with the recognised text
        (metadata: ``source`` and ``type`` "image"), or an empty list when
        the image cannot be processed or contains no text. Processing errors
        are logged as warnings rather than raised.
    """
    # `Image` is used below but is not part of this file's visible import
    # block. Without this import the resulting NameError was swallowed by the
    # handler below and every call returned [].
    from PIL import Image

    logger.info(f"Extracting text from image: {image_path}")
    try:
        # The context manager releases the file handle once OCR is done.
        with Image.open(image_path) as image:
            text = pytesseract.image_to_string(image, lang='eng')
    except Exception as e:
        logger.warning(f"Failed to process image {image_path}: {e}")
        return []

    if not text.strip():
        logger.warning(f"No text extracted from image {image_path}.")
        return []

    return [Document(
        page_content=text,
        metadata={"source": image_path, "type": "image"}
    )]


def setup_rag_pipeline():
    """Initializes the RAG pipeline by creating or loading a FAISS vector store with universal ingestion.

    If an index already exists at ``FAISS_INDEX_PATH`` it is loaded and
    returned. Otherwise every supported file under ``RAG_DOCS_DIR`` is
    ingested (text formats through loaders, PDFs with an OCR fallback,
    standalone images through OCR), split into chunks, embedded, and saved as
    a new index.

    Returns:
        The FAISS vector store.

    Raises:
        ValueError: If no documents were found and the fallback download of
            the OWASP Top 10 PDF failed.
    """
    logger.info("--- Setting up RAG pipeline with universal ingestion ---")

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    if os.path.exists(FAISS_INDEX_PATH):
        logger.info(f"--- Loading existing FAISS index from '{FAISS_INDEX_PATH}' ---")
        # SECURITY NOTE: allow_dangerous_deserialization=True unpickles the
        # stored index. Only load index files this pipeline produced itself.
        return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)

    logger.info("--- No existing FAISS index found. Building a new one. ---")

    # Loaders for text-like formats. PDFs are deliberately NOT listed here:
    # they are handled one by one below (PyPDFLoader first, OCR as fallback),
    # and listing them in both places would ingest every text PDF twice.
    loaders = {
        "**/*.md": TextLoader,
        "**/*.txt": TextLoader,
        "**/*.docx": Docx2txtLoader,
        "**/*.csv": CSVLoader,
        "**/*.html": BSHTMLLoader
    }
    documents = []

    # Load standard documents. The loop variable is named `pattern` so it
    # does not shadow the `glob` module.
    for pattern, loader_cls in loaders.items():
        try:
            loader = DirectoryLoader(RAG_DOCS_DIR, glob=pattern, loader_cls=loader_cls, show_progress=True, use_multithreading=True, silent_errors=True)
            documents.extend(loader.load())
        except Exception as e:
            logger.warning(f"Could not load files with pattern {pattern} using {loader_cls.__name__}. Error: {e}")

    # PDFs: try native text extraction, fall back to OCR for scanned files.
    pdf_files = [
        os.path.join(root, f)
        for root, _, files in os.walk(RAG_DOCS_DIR)
        for f in files
        if f.lower().endswith(".pdf")
    ]
    for pdf_path in pdf_files:
        try:
            docs = PyPDFLoader(pdf_path).load()
            if any(doc.page_content.strip() for doc in docs):
                documents.extend(docs)
            else:
                logger.info(f"No text extracted with PyPDFLoader for {pdf_path}. Attempting OCR.")
                documents.extend(extract_text_from_scanned_pdf(pdf_path))
        except Exception as e:
            logger.warning(f"Could not process PDF {pdf_path}. Error: {e}")

    # Handle standalone images with OCR
    image_extensions = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
    image_files = [
        os.path.join(root, f)
        for root, _, files in os.walk(RAG_DOCS_DIR)
        for f in files
        if f.lower().endswith(image_extensions)
    ]
    for image_path in image_files:
        documents.extend(extract_text_from_image(image_path))

    if not documents:
        logger.warning(f"--- No supported documents found in '{RAG_DOCS_DIR}'. Attempting to auto-download key resources ---")
        # Auto-download OWASP Top 10 2021 PDF
        owasp_url = "https://owasp.org/Top10/assets/PDF/OWASP-Top-10-2021.pdf"
        try:
            # The docs directory may not exist yet when nothing was found.
            os.makedirs(RAG_DOCS_DIR, exist_ok=True)
            # A timeout keeps a stalled connection from hanging the setup.
            response = requests.get(owasp_url, timeout=30)
            response.raise_for_status()
            owasp_path = os.path.join(RAG_DOCS_DIR, "owasp_top10_2021.pdf")
            with open(owasp_path, 'wb') as f:
                f.write(response.content)
            logger.info(f"--- Downloaded OWASP Top 10 2021 PDF to '{owasp_path}' ---")
            # Try PyPDFLoader for OWASP PDF (it's text-based)
            documents.extend(PyPDFLoader(owasp_path).load())
        except Exception as e:
            logger.error(f"--- Failed to auto-download OWASP PDF: {e} ---")
            # Chain the cause so the original failure stays in the traceback.
            raise ValueError(f"No documents available for RAG. Please add files to '{RAG_DOCS_DIR}'.") from e

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(documents)

    logger.info(f"--- Creating FAISS index from {len(docs)} document chunks. This may take a moment... ---")
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(FAISS_INDEX_PATH)
    logger.info(f"--- FAISS index created and saved to '{FAISS_INDEX_PATH}' ---")

    return db

# --- Initialize RAG and OpenAI Client ---
try:
    # Build or load the vector store first; the LLM client is only useful
    # once retrieval is available.
    rag_db = setup_rag_pipeline()
    client = OpenAI(
        # The endpoint can be overridden through SCW_API_URL; the default is
        # the previously hard-coded Scaleway URL, so existing setups behave
        # exactly as before.
        base_url=os.getenv("SCW_API_URL", "https://api.scaleway.ai/4a8fd76b-8606-46e6-afe6-617ce8eeb948/v1"),
        api_key=os.getenv("SCW_SECRET_KEY")
    )
    logger.info("--- OpenAI client initialized successfully ---")
except Exception as e:
    # Nothing below can work without these services: log and abort.
    logger.error(f"--- Failed to initialize services: {e} ---")
    raise

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    # Schema for a single generated threat. All fields are required; no value
    # constraints are declared here beyond the types.
    component_name: str  # name of the DFD component the threat applies to
    stride_category: str  # STRIDE category identifier as returned by the LLM
    threat_description: str  # free-text description of the threat
    mitigation_suggestion: str  # free-text mitigation advice
    impact: str  # impact rating (the prompt asks for Low, Medium, or High)
    likelihood: str  # likelihood rating (the prompt asks for Low, Medium, or High)
    references: list[str]  # reference strings such as CWE or OWASP identifiers
    risk_score: str  # overall risk label; recalculated by the main script

class ThreatsOutput(BaseModel):
    # Top-level shape of the JSON document written by the main script.
    threats: list[Threat]  # all generated threats
    metadata: dict  # run information (timestamp, source DFD, model, index path)

# --- Sample DFD for Testing ---
# Minimal data-flow diagram used as a fallback when the DFD input file is
# missing or empty. Each top-level key maps to a list of component dicts.
SAMPLE_DFD = {
    "external_entities": [{"name": "User"}],
    "processes": [{"name": "Web Application"}, {"name": "Authentication Service"}],
    "data_stores": [{"name": "User Database"}],
    "data_flows": [
        {"source": "User", "destination": "Web Application", "data_description": "Login Credentials", "protocol": "HTTPS"},
        {"source": "Web Application", "destination": "User Database", "data_description": "Query User Data", "protocol": "SQL"}
    ],
    "trust_boundaries": [{"name": "Internet to DMZ"}]
}

# --- Load DFD Components ---
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    # Decode explicitly as UTF-8 so non-ASCII component names are read the
    # same way regardless of the platform's locale encoding.
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
    if not dfd_data:
        # An empty JSON value ({} / [] / null) is treated like a missing file.
        logger.warning(f"DFD file at '{DFD_INPUT_PATH}' is empty. Using sample DFD for demonstration.")
        dfd_data = SAMPLE_DFD
except FileNotFoundError:
    logger.warning(f"DFD file not found at '{DFD_INPUT_PATH}'. Using sample DFD for demonstration.")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError as e:
    # A file that exists but is not valid JSON is a hard error: falling back
    # silently would hide a broken upstream stage.
    logger.error(f"FATAL: Error decoding JSON from '{DFD_INPUT_PATH}': {e}")
    exit(1)
except Exception as e:
    logger.error(f"FATAL: Error loading DFD: {e}")
    exit(1)

# --- Enhanced Prompting Strategy ---

# **FIX 1: Define STRIDE categories explicitly to ensure systematic coverage.**
# Parametrized: Load from config or file if needed
# Built-in STRIDE catalogue: category letter -> (category name, definition).
STRIDE_DEFINITIONS = {
    "S": ("Spoofing", "Illegitimately accessing systems or data by impersonating a user, process, or component."),
    "T": ("Tampering", "Unauthorized modification of data, either in transit or at rest."),
    "R": ("Repudiation", "A user or system denying that they performed an action, often due to a lack of sufficient proof (e.g., logs)."),
    "I": ("Information Disclosure", "Exposing sensitive information to unauthorized individuals."),
    "D": ("Denial of Service", "Preventing legitimate users from accessing a system or service."),
    "E": ("Elevation of Privilege", "A user or process gaining rights beyond their authorized level.")
}

# Optional: Load custom STRIDE from file. The file replaces the built-in
# catalogue wholesale; each value must unpack into (name, definition), which a
# two-element JSON array satisfies.
stride_config_path = "stride_config.json"
if os.path.exists(stride_config_path):
    # Explicit UTF-8 keeps decoding independent of the platform locale.
    with open(stride_config_path, 'r', encoding='utf-8') as f:
        STRIDE_DEFINITIONS = json.load(f)
    logger.info("--- Loaded custom STRIDE definitions from 'stride_config.json' ---")

# **FIX 2: Create a highly specific prompt template focused on a SINGLE STRIDE category.**
# This prevents generic responses and forces the model to generate relevant, accurate threats.
# Placeholders filled via str.format(): component_info, stride_category,
# stride_name, stride_definition, rag_context. The doubled braces around the
# JSON schema are format() escapes that render as literal braces.
threat_prompt_template_specific_rag = """
You are a cybersecurity architect specializing in threat modeling using the STRIDE methodology.
Your task is to generate 1-2 specific threats for a given DFD component, focusing ONLY on a single STRIDE category.

**DFD Component to Analyze:**
{component_info}

**STRIDE Category to Focus On:**
- **{stride_category} ({stride_name}):** {stride_definition}

**Security Context from Knowledge Base (for accuracy):**
'''
{rag_context}
'''

**Instructions:**
1.  Generate 1-2 distinct and realistic threats for the component that fall **strictly** under the '{stride_name}' category.
2.  **Be specific.** Relate the threat directly to the component's type and details. For a database, a Spoofing threat is a spoofed connection, not user impersonation. For a data flow, a Tampering threat is a Man-in-the-Middle attack.
3.  Use the provided Security Context to create specific descriptions, **actionable mitigations**, and accurate references (e.g., CWE, OWASP Cheat Sheets). Do not invent references.
4.  Provide a realistic risk assessment (Impact, Likelihood, Score).
5.  Output ONLY a valid JSON object with a single key "threats", containing a list of threat objects. Do not include any other text or commentary.

**JSON Threat Object Schema:**
{{
  "component_name": "string (the name of the component being analyzed)",
  "stride_category": "{stride_category}",
  "threat_description": "string (Specific to the component and STRIDE category)",
  "mitigation_suggestion": "string (Actionable and specific)",
  "impact": "Low, Medium, or High",
  "likelihood": "Low, Medium, or High",
  "references": ["list of strings, e.g., 'OWASP A01:2021', 'CWE-89'"],
  "risk_score": "Critical, High, Medium, or Low"
}}
"""

# Appended to the prompt when a first attempt returned no threats.
retry_prompt_addition = " Generate at least one threat if realistically applicable, even if minor."

# Function to calculate risk_score
def calculate_risk_score(impact, likelihood):
    """Map an (impact, likelihood) pair onto a risk label.

    Args:
        impact: Impact rating; "High", "Medium" and "Low" are recognised.
        likelihood: Likelihood rating; "High", "Medium" and "Low" are
            recognised.

    Returns:
        "Critical", "High", "Medium" or "Low". Any pair not listed in the
        table below (including unrecognised labels) falls through to "Low".
    """
    rating = (impact, likelihood)
    # Tiers are checked from most to least severe; comparison is by equality
    # so unexpected value types simply fail to match.
    tiers = (
        ("Critical", (("High", "Medium"), ("High", "High"))),
        ("High", (("High", "Low"), ("Medium", "High"))),
        ("Medium", (("Medium", "Medium"), ("Medium", "Low"), ("Low", "High"))),
    )
    for label, combinations in tiers:
        if rating in combinations:
            return label
    return "Low"

# --- Main Invocation Logic ---
# Drives the whole generation run: collect DFD components, query the LLM once
# per (component, STRIDE category) with retrieved context, then deduplicate,
# score, sort, validate and save the resulting threats.
logger.info("\n--- Invoking LLM with RAG to systematically generate STRIDE threats ---")
all_threats = []
try:
    # Every dict found in a non-empty list under a DFD key is one component.
    # A missing 'name' does not exclude it; the serialized component is used
    # as its display name further down.
    components_to_analyze = []
    for key, value in dfd_data.items():
        if isinstance(value, list) and value:
            for item in value:
                if isinstance(item, dict):
                    components_to_analyze.append({"type": key, "details": item})

    # **FIX 3: Iterate through each component AND each STRIDE category.**
    # This loop structure ensures every category is considered for every component.
    for component in components_to_analyze:
        component_str = json.dumps(component)
        component_name = component.get("details", {}).get("name", component_str)
        logger.info(f"\n--- Analyzing component: {component_name} ---")

        # One retrieval per component; the context is shared by all categories.
        retrieved_docs = rag_db.similarity_search(component_str, k=5)  # Increased to 5 for broader context
        rag_context = "\n---\n".join([doc.page_content for doc in retrieved_docs])
        logger.info("--- Retrieved RAG context for component ---")

        for cat_letter, (cat_name, cat_def) in STRIDE_DEFINITIONS.items():
            logger.info(f"--- Generating threats for STRIDE category: {cat_name} ---")

            prompt = threat_prompt_template_specific_rag.format(
                component_info=component_str,
                rag_context=rag_context,
                stride_category=cat_letter,
                stride_name=cat_name,
                stride_definition=cat_def
            )

            retry_count = 0
            max_retries = 1  # Retry once if no threats
            threats = []
            while retry_count <= max_retries and not threats:
                try:
                    # Build the retry prompt from the base prompt each time so
                    # the suffix never accumulates if max_retries is raised.
                    attempt_prompt = prompt if retry_count == 0 else prompt + retry_prompt_addition

                    response = client.chat.completions.create(
                        model=LLM_MODEL,
                        messages=[{"role": "user", "content": attempt_prompt}],
                        response_format={"type": "json_object"},
                        max_tokens=2048,
                        temperature=0.4 # Slightly lower temp for more focused output
                    )

                    response_content = response.choices[0].message.content
                    generated_data = json.loads(response_content)

                    # Ensure the response is a dict with a 'threats' key which is a list
                    if isinstance(generated_data, dict) and isinstance(generated_data.get("threats"), list):
                        raw_threats = generated_data["threats"]
                        # Keep only object entries: a stray string or number
                        # in the list would otherwise abort the whole run in
                        # the deduplication step below.
                        threats = [t for t in raw_threats if isinstance(t, dict)]
                        if len(threats) != len(raw_threats):
                            logger.warning(f"--- Discarded {len(raw_threats) - len(threats)} malformed threat entries for {cat_name} on {component_name}. ---")
                        # Add component name if missing from LLM response
                        for threat in threats:
                            if 'component_name' not in threat or not threat['component_name']:
                                threat['component_name'] = component_name
                        logger.info(f"--- Successfully generated {len(threats)} threat(s) for category {cat_name} ---")
                    else:
                        logger.warning(f"--- LLM response for {cat_name} on {component_name} had unexpected structure. ---")
                        logger.debug(f"Raw Response: {response_content}")

                except (json.JSONDecodeError, AttributeError) as e:
                    logger.warning(f"--- Could not parse LLM response for {cat_name} on {component_name}: {e} ---")
                except Exception as e:
                    logger.error(f"--- An API error occurred for {cat_name} on {component_name}: {e} ---")

                retry_count += 1

            all_threats.extend(threats)

    # Deduplication: Remove exact duplicates based on description
    unique_threats = []
    seen_descriptions = set()
    for threat in all_threats:
        desc = threat.get('threat_description', '')
        if desc not in seen_descriptions:
            seen_descriptions.add(desc)
            # Recalculate risk_score so it is derived from impact/likelihood
            # rather than taken from the model's own label.
            threat['risk_score'] = calculate_risk_score(threat.get('impact', 'Low'), threat.get('likelihood', 'Low'))
            unique_threats.append(threat)
    all_threats = unique_threats
    logger.info(f"--- Deduplicated threats: {len(all_threats)} unique threats remaining ---")

    # --- Final Processing and Validation ---
    # Highest risk first; unknown labels sort last.
    risk_order = {"Critical": 4, "High": 3, "Medium": 2, "Low": 1, "Informational": 0}
    all_threats.sort(key=lambda t: risk_order.get(t.get('risk_score', 'Low'), 0), reverse=True)

    final_output = {
        "threats": all_threats,
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "source_dfd": os.path.basename(DFD_INPUT_PATH),
            "llm_model": LLM_MODEL,
            "rag_index": FAISS_INDEX_PATH
        }
    }

    # Validation is advisory: a failure is logged, but the output is still
    # written so the generated threats are not lost.
    try:
        validated_output = ThreatsOutput(**final_output)
        logger.info("--- Final JSON output validated successfully against schema ---")
    except ValidationError as ve:
        logger.error(f"--- FINAL JSON VALIDATION FAILED: {ve} ---")

    # Explicit UTF-8 so the file does not depend on the platform locale.
    with open(THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, indent=2)

    logger.info("\n--- LLM RAG Output (Identified Threats) ---")
    logger.info(f"\n--- Identified {len(all_threats)} threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error("\n--- An error occurred during the threat generation process ---")
    logger.error(f"Error: {e}", exc_info=True)

2025-07-29 20:17:51,001 - INFO - --- Setting up RAG pipeline with universal ingestion ---
2025-07-29 20:17:54,009 - INFO - Use pytorch device_name: mps
2025-07-29 20:17:54,010 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
2025-07-29 20:17:57,042 - INFO - --- No existing FAISS index found. Building a new one. ---
 27%|██▋       | 3/11 [00:03<00:10,  1.29s/it]

In [58]:
import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import DBSCAN
import requests
import pandas as pd

# --- Configuration ---
# Pull settings from a .env file (if present) into the environment first.
load_dotenv()
# All paths default to files inside INPUT_DIR; each can be overridden
# individually through the environment variable of the same name.
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
# NOTE(review): this reads the THREATS_OUTPUT_PATH variable, presumably so the
# refinement step picks up the file the generation step wrote — confirm.
THREATS_INPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))
REFINED_THREATS_OUTPUT_PATH = os.getenv("REFINED_THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "refined_threats.json"))
CONTROLS_INPUT_PATH = os.getenv("CONTROLS_INPUT_PATH", os.path.join(INPUT_DIR, "controls.json"))
# NOTE(review): NVD_API_URL is not referenced in the visible part of this file.
NVD_API_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"
CISA_KEV_URL = "https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json"


# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure directories exist
os.makedirs(INPUT_DIR, exist_ok=True)

# --- Threat Schema ---
class Threat(BaseModel):
    # Schema of one refined threat. Every field is required (Field(...)), and
    # the rating fields are restricted to fixed labels via regex patterns.
    component_name: str = Field(..., description="Standardized name of the component or data flow")
    # Exactly one character out of S, T, R, I, D, E.
    stride_category: str = Field(..., pattern="^[STRIDE]$", description="STRIDE category (S, T, R, I, D, E)")
    threat_description: str = Field(..., description="Detailed description of the threat")
    mitigation_suggestion: str = Field(..., description="Actionable mitigation specific to the threat")
    # Impact allows "Critical" in addition to the three likelihood levels.
    impact: str = Field(..., pattern="^(Critical|High|Medium|Low)$", description="Impact level")
    likelihood: str = Field(..., pattern="^(Low|Medium|High)$", description="Likelihood level")
    references: list[str] = Field(..., description="List of references (e.g., CWE, CVE, OWASP)")
    risk_score: str = Field(..., pattern="^(Critical|High|Medium|Low)$", description="Derived risk score")
    residual_risk_score: str = Field(..., pattern="^(Critical|High|Medium|Low)$", description="Risk score post-mitigation")
    exploitability: str = Field(..., pattern="^(Low|Medium|High)$", description="Ease of exploitation")
    mitigation_maturity: str = Field(..., pattern="^(Immature|Mature|Advanced)$", description="Maturity of mitigation controls")
    justification: str = Field(..., description="Rationale for impact and likelihood ratings")
    risk_statement: str = Field(..., description="Business-contextualized risk description")

class RefinedThreatsOutput(BaseModel):
    # Top-level container: the refined threats plus a free-form metadata dict.
    threats: list[Threat]
    metadata: dict

# --- Caching for External APIs ---
# Module-level cache: None until the first fetch attempt, afterwards a set of
# CVE IDs (empty when the fetch failed).
cisa_kev_cache = None

def get_cisa_kev_catalog():
    """Fetches and caches the CISA KEV catalog.

    Returns:
        A set of CVE ID strings listed in the catalog. The download is
        attempted only once per process; if it fails for any reason an empty
        set is cached and returned so callers can proceed without the
        catalog.
    """
    global cisa_kev_cache
    if cisa_kev_cache is None:
        try:
            logger.info("--- Fetching CISA Known Exploited Vulnerabilities (KEV) catalog ---")
            response = requests.get(CISA_KEV_URL, timeout=10)
            response.raise_for_status()
            cisa_kev_cache = {vuln['cveID'] for vuln in response.json().get('vulnerabilities', [])}
            logger.info(f"--- Successfully loaded {len(cisa_kev_cache)} entries from CISA KEV catalog ---")
        except (requests.RequestException, ValueError, KeyError, TypeError, AttributeError) as e:
            # Besides network errors, a malformed payload (invalid JSON,
            # unexpected structure, entry without 'cveID') must not crash
            # this best-effort lookup either.
            logger.error(f"--- Failed to fetch CISA KEV catalog: {e}. Proceeding without it. ---")
            cisa_kev_cache = set()
    return cisa_kev_cache

# --- Helper Functions ---
def load_dfd_components():
    """Load DFD components to provide full data flow details.

    Returns:
        The parsed JSON content of ``DFD_INPUT_PATH``.

    Raises:
        Exception: Any error while opening or parsing the file is logged and
            re-raised unchanged.
    """
    logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"--- Failed to load DFD components: {e} ---")
        raise

def load_controls():
    """Load client-provided controls to suppress threats.

    Returns:
        The dict parsed from ``CONTROLS_INPUT_PATH`` when that file exists and
        contains a JSON object; otherwise a dict of default control settings.
        Read or parse failures are logged as warnings and also produce the
        defaults.
    """
    # Defined once; a fresh dict per call, so callers may mutate the result.
    default_controls = {"https_enabled": False, "tls_version": "1.2", "mtls_enabled": False, "secrets_manager": False}

    logger.info(f"--- Loading controls from '{CONTROLS_INPUT_PATH}' ---")
    try:
        if not os.path.exists(CONTROLS_INPUT_PATH):
            return default_controls
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(CONTROLS_INPUT_PATH, 'r', encoding='utf-8') as f:
            controls = json.load(f)
        if not isinstance(controls, dict):
            # Callers use controls.get(...); anything but an object would
            # break them later with a less helpful error.
            raise ValueError("controls file must contain a JSON object")
        return controls
    except Exception as e:
        logger.warning(f"--- Failed to load controls, using defaults: {e} ---")
        return default_controls

def check_cve_relevance(cve_id):
    """Decide whether a CVE reference should be kept as relevant.

    Args:
        cve_id: CVE identifier; the year is read from the part between the
            first and second hyphen.

    Returns:
        True when the CVE is in the CISA KEV catalog, when its year is within
        the last five years, or when the year cannot be parsed. False only
        for a CVE older than five years that is not in the catalog.
    """
    # Known-exploited vulnerabilities stay relevant regardless of age.
    if cve_id in get_cisa_kev_catalog():
        logger.warning(f"--- CVE {cve_id} is in the CISA KEV catalog and should NOT be suppressed. ---")
        return True

    try:
        cve_year = int(cve_id.split('-')[1])
    except (ValueError, IndexError):
        # Unparseable IDs are kept rather than silently dropped.
        logger.warning(f"--- Could not parse year from CVE ID {cve_id}. Assuming it's relevant. ---")
        return True

    oldest_relevant_year = datetime.now().year - 5
    if cve_year >= oldest_relevant_year:
        return True

    logger.info(f"--- CVE {cve_id} is older than 5 years and not in KEV catalog. Considered for suppression. ---")
    return False

def calculate_risk_score(impact, likelihood):
    """Derive a risk label from the product of impact and likelihood weights.

    Args:
        impact: "Critical", "High", "Medium" or "Low" (weights 4 to 1).
        likelihood: "High", "Medium" or "Low" (weights 3 to 1).

    Returns:
        "Critical" for a product of 9 or more, "High" for 6 or more, "Medium"
        for 3 or more, otherwise "Low". Unrecognised labels weigh 1.
    """
    impact_weight = {"Critical": 4, "High": 3, "Medium": 2, "Low": 1}.get(impact, 1)
    likelihood_weight = {"High": 3, "Medium": 2, "Low": 1}.get(likelihood, 1)
    product = impact_weight * likelihood_weight

    # Thresholds in descending order; the first one reached wins.
    for minimum, label in ((9, "Critical"), (6, "High"), (3, "Medium")):
        if product >= minimum:
            return label
    return "Low"

def assess_exploitability(threat, dfd_data):
    """Assess exploitability based on component exposure and protocol.

    Args:
        threat: Threat dict; its ``component_name`` is matched against data
            flows formatted as "<source> to <destination>".
        dfd_data: DFD dict; only its ``data_flows`` list is consulted.

    Returns:
        "High" when the matching flow originates from source "U", "Medium"
        when its protocol mentions TLS or HTTPS, and "Low" otherwise —
        including when the threat does not correspond to any data flow.
    """
    component_name = threat["component_name"]
    # Find the corresponding data flow to check its source and protocol.
    flow = next(
        (
            f for f in dfd_data.get("data_flows", [])
            if f"{f.get('source')} to {f.get('destination')}" == component_name
        ),
        None,
    )

    # Threats on processes, data stores, etc. have no matching flow. The
    # original condition mixed `and`/`or` without parentheses and ended up
    # calling .get() on None in that case.
    if flow is None:
        return "Low"

    if flow.get("source") == "U":  # 'U' is the external user
        return "High"

    # Tolerate a missing or null protocol entry.
    protocol = flow.get("protocol") or ""
    if "TLS" in protocol or "HTTPS" in protocol:
        return "Medium"
    return "Low"


def assess_mitigation_maturity(mitigation):
    """Classify a mitigation text by the control keywords it mentions.

    Args:
        mitigation: Mitigation description; matched case-insensitively.

    Returns:
        "Advanced", "Mature" or "Immature" according to the first tier whose
        keyword occurs in the text; "Mature" when no keyword matches.
    """
    text = mitigation.lower()
    # Tiers are ordered by precedence: an earlier tier wins when keywords
    # from several tiers appear in the same text.
    keyword_tiers = (
        ("Advanced", ("end-to-end encryption", "certificate pinning")),
        ("Mature", ("mtls", "waf", "rate limiting", "secrets management")),
        ("Immature", ("logging", "auditing")),
    )
    for level, keywords in keyword_tiers:
        if any(keyword in text for keyword in keywords):
            return level
    return "Mature"

def standardize_component_name(original_name, valid_flows):
    """Standardize component names to match DFD format.

    Args:
        original_name: Component name as produced upstream.
        valid_flows: Data-flow dicts with ``source`` and ``destination`` keys;
            their canonical names are "<source> to <destination>".

    Returns:
        The canonical flow name when the cleaned-up input equals one or
        contains one (case-insensitively); otherwise ``original_name``
        unchanged, with a warning logged.
    """
    # Keep DFD order and drop duplicates so the fuzzy fallback below is
    # deterministic (iterating a set of strings is not stable across runs).
    valid_component_names = list(dict.fromkeys(
        f"{flow['source']} to {flow['destination']}" for flow in valid_flows
    ))

    # Clean up common variations and collapse runs of whitespace.
    normalized = original_name.replace("Data Flow from ", "").replace(" data flow", "")
    normalized = " ".join(normalized.split())

    if normalized in valid_component_names:
        return normalized

    # Fallback for close matches: when several canonical names are contained
    # in the input, prefer the longest, i.e. the most specific one.
    lowered = normalized.lower()
    matches = [name for name in valid_component_names if name.lower() in lowered]
    if matches:
        return max(matches, key=len)

    logger.warning(f"--- Component name '{original_name}' not found in DFD; retaining original. ---")
    return original_name

def generate_justification(threat, flow_details):
    """Generate tailored justification for impact and likelihood based on data classification.

    Args:
        threat: Threat dict providing ``impact``, ``likelihood`` and
            ``component_name``.
        flow_details: Matching data-flow dict, or None when the threat does
            not map to a data flow.

    Returns:
        A two-sentence string: the impact rationale followed by the
        likelihood rationale.
    """
    impact = threat['impact']
    likelihood = threat['likelihood']
    # A missing, null or empty classification counts as "Unclassified" so the
    # text never reads "handles 'None' data".
    data_classification = (flow_details.get("data_classification") if flow_details else None) or "Unclassified"

    # Justification for Impact
    impact_reason = f"Impact rated {impact} because "
    if data_classification != "Unclassified":
        impact_reason += f"the data flow handles '{data_classification}' data, "
        if data_classification in ["PII", "Confidential", "PHI", "PCI"]:
            impact_reason += "and a breach could lead to regulatory fines and significant reputational damage."
        else:
            impact_reason += "and a breach could cause moderate business disruption."
    elif "DB_P" in threat["component_name"]:  # Fallback if no classification
        impact_reason += "of potential exposure of sensitive data in the primary database, leading to severe reputational damage."
    else:  # Generic fallback keyed on the impact label
        impact_reason += {
            "Critical": "of potential for severe business disruption or data breach.",
            "High": "of potential for significant business disruption or data exposure.",
            "Medium": "of moderate disruption or partial data exposure.",
            "Low": "of minimal operational impact."
        }.get(impact, "of minimal operational impact.")

    # Justification for Likelihood: flows starting at source 'U' are treated
    # as internet-facing.
    likelihood_reason = f"Likelihood rated {likelihood} because "
    if flow_details and flow_details.get("source") == 'U':
        likelihood_reason += "the component is internet-facing, increasing the attack surface."
    else:
        likelihood_reason += "the component is internal, reducing direct exposure."

    return f"{impact_reason} {likelihood_reason}"


def generate_risk_statement(threat, flow_details, industry="Generic"):
    """Generate a business-contextualized risk statement using data classification.

    Args:
        threat: Threat dict providing ``component_name``,
            ``threat_description``, ``impact``, ``risk_score``,
            ``residual_risk_score`` and ``mitigation_suggestion``.
        flow_details: Matching data-flow dict, or None.
        industry: Client industry; "Finance" and "Healthcare" add a
            compliance remark for PCI and PHI data respectively.

    Returns:
        The risk statement as a single string.

    Raises:
        KeyError: If ``threat['impact']`` is not one of Critical, High,
            Medium, Low.
    """
    impact_map = {
        "Critical": "a critical event, potentially causing severe financial loss (e.g., >$1M), major regulatory fines, and long-term reputational damage",
        "High": "significant financial loss (e.g., >$500K), regulatory fines, or reputational damage",
        "Medium": "moderate financial loss (e.g., $50K-$500K) or operational disruption",
        "Low": "minimal financial or operational impact"
    }
    component = threat["component_name"]
    # A missing, null or empty classification falls back to the word "data".
    data_classification = (flow_details.get("data_classification") if flow_details else None) or "data"

    risk = f"Risk of {threat['threat_description'].lower()} on the '{component}' flow, which handles **{data_classification}**, could lead to {impact_map[threat['impact']]}."

    if industry == "Finance" and data_classification == "PCI":
        risk += " This may violate PCI-DSS compliance."
    elif industry == "Healthcare" and data_classification == "PHI":
        risk += " This may violate HIPAA regulations."

    # Compare severities by rank: comparing the label strings directly would
    # order them alphabetically ("Critical" < "High" < "Low" < "Medium").
    severity_rank = {"Low": 1, "Medium": 2, "High": 3, "Critical": 4}
    residual_rank = severity_rank.get(threat['residual_risk_score'])
    inherent_rank = severity_rank.get(threat['risk_score'])
    if residual_rank is not None and inherent_rank is not None and residual_rank < inherent_rank:
        risk += f" The proposed mitigation, '{threat['mitigation_suggestion']}', is expected to reduce the risk to '{threat['residual_risk_score']}'."

    return risk

def suppress_threats(threats, controls, dfd_data):
    """Suppress or downgrade threats based on implemented controls and CVE relevance.

    Args:
        threats: List of threat dicts. Surviving threats have their
            ``references`` list replaced in place with the filtered list.
        controls: Control settings; ``mtls_enabled`` and ``secrets_manager``
            are consulted.
        dfd_data: DFD dict. Accepted for interface compatibility; it is not
            needed for the current suppression rules.

    Returns:
        The threats that were not suppressed, in their original order.
    """
    active_threats = []
    for threat in threats:
        suppress = False
        component = threat["component_name"]
        description = str(threat.get("threat_description", "")).lower()

        # Suppress based on controls
        if controls.get("mtls_enabled") and "spoof" in description:
            logger.info(f"--- Suppressing '{component}' ({threat['stride_category']}) due to mTLS control. ---")
            suppress = True
        if controls.get("secrets_manager") and "cleartext" in description:
            logger.info(f"--- Suppressing '{component}' ({threat['stride_category']}) due to secrets management. ---")
            suppress = True

        # Suppress based on irrelevant CVEs
        if not suppress:
            # `or []` also covers an explicit null in the input JSON.
            original_references = threat.get("references") or []
            relevant_references = []
            for ref in original_references:
                # Only string references can be CVE IDs; anything else is
                # kept untouched.
                if isinstance(ref, str) and ref.startswith("CVE-") and not check_cve_relevance(ref):
                    logger.info(f"--- Removing outdated/irrelevant CVE '{ref}' from threat '{component}'. ---")
                else:
                    relevant_references.append(ref)

            # If all references were irrelevant CVEs, suppress the threat
            if original_references and not relevant_references:
                logger.info(f"--- Suppressing threat for '{component}' as its only CVE references were irrelevant. ---")
                suppress = True
            else:
                threat["references"] = relevant_references

        if not suppress:
            active_threats.append(threat)

    return active_threats

def deduplicate_threats(threats, similarity_threshold=0.80):
    """Merge near-duplicate threats found by embedding-based clustering.

    Each threat is embedded from its description plus mitigation text and
    clustered with DBSCAN (cosine metric, ``eps = 1 - similarity_threshold``).
    Threats are merged only when they share cluster, component name and
    STRIDE category.

    Args:
        threats: List of threat dicts. The threat kept for a merged group is
            modified in place (mitigation and references).
        similarity_threshold: Similarity level used to derive the DBSCAN
            ``eps`` value.

    Returns:
        A list with one threat per group; an empty list for empty input.
    """
    if not threats:
        return []
    logger.info("--- Starting threat deduplication ---")
    encoder = SentenceTransformer('all-mpnet-base-v2')

    # Description and mitigation together carry the semantic meaning.
    texts = []
    for entry in threats:
        texts.append(f"{entry['threat_description']} {entry['mitigation_suggestion']}")
    vectors = encoder.encode(texts, convert_to_tensor=True).cpu().numpy()

    # min_samples=1 means every threat lands in some cluster.
    dbscan = DBSCAN(eps=1 - similarity_threshold, min_samples=1, metric="cosine")
    labels = dbscan.fit(vectors).labels_

    # Bucket indices by (cluster, component, STRIDE category); dict insertion
    # order keeps groups in first-seen order.
    buckets = {}
    for position, cluster_id in enumerate(labels):
        entry = threats[position]
        bucket_key = (cluster_id, entry["component_name"], entry["stride_category"])
        buckets.setdefault(bucket_key, []).append(position)

    deduplicated_threats = []
    for members in buckets.values():
        if len(members) == 1:
            deduplicated_threats.append(threats[members[0]])
            continue

        candidates = [threats[i] for i in members]
        # Keep the threat with the longest description and give it the
        # longest mitigation text found in the group.
        primary_threat = max(candidates, key=lambda t: len(t.get('threat_description', '')))
        longest_mitigation = max(candidates, key=lambda t: len(t.get('mitigation_suggestion', '')))
        primary_threat['mitigation_suggestion'] = longest_mitigation.get('mitigation_suggestion')

        # Union of all references in the group, sorted.
        merged_references = set()
        for candidate in candidates:
            merged_references.update(candidate.get("references", []))
        primary_threat["references"] = sorted(merged_references)

        deduplicated_threats.append(primary_threat)
        logger.info(f"--- Merged {len(members)} similar threats for '{primary_threat['component_name']}' ({primary_threat['stride_category']}) ---")

    logger.info(f"--- Deduplication reduced {len(threats)} threats to {len(deduplicated_threats)} ---")
    return deduplicated_threats

# --- Main Refinement Logic ---
def refine_threats():
    """Refine threats by deduplicating, standardizing, and enhancing with business risk context.

    Pipeline, in order:
      1. Load the initial threats from ``THREATS_INPUT_PATH``.
      2. Rewrite every threat's ``component_name`` via ``standardize_component_name``.
      3. Filter the list through ``suppress_threats`` and ``deduplicate_threats``.
      4. Enrich each remaining threat in place (impact, risk scores,
         exploitability, mitigation maturity, justification, risk statement).
      5. Sort by ``risk_score`` (Critical first) and validate the result
         against ``RefinedThreatsOutput``.
      6. Write the refined JSON, a summary JSON and a CSV report.

    Returns:
        None. Results are written to disk.

    Raises:
        ValueError: If the input file contains no threats.
        ValidationError: If the assembled output does not match
            ``RefinedThreatsOutput``.
        Exception: Any error raised while reading the input file is logged
            and re-raised unchanged.
    """
    logger.info(f"--- Loading initial threats from '{THREATS_INPUT_PATH}' ---")
    try:
        # Explicit UTF-8 so the read does not depend on the platform locale.
        with open(THREATS_INPUT_PATH, 'r', encoding='utf-8') as f:
            threat_data = json.load(f)
        threats = threat_data.get("threats", [])
        if not threats:
            raise ValueError("Input file contains no threats.")
    except Exception as e:
        logger.error(f"--- Failed to load threats: {e} ---")
        raise

    dfd_data = load_dfd_components()
    controls = load_controls()
    industry = os.getenv("CLIENT_INDUSTRY", "Generic")
    dfd_flows = dfd_data.get("data_flows", [])
    # Captured before suppression/deduplication so the metadata can report both counts.
    original_threat_count = len(threats)

    # Step 1: Standardize component names to match DFD
    for threat in threats:
        threat["component_name"] = standardize_component_name(threat["component_name"], dfd_flows)

    # Step 2: Suppress threats based on controls and CVE relevance
    threats = suppress_threats(threats, controls, dfd_data)

    # Step 3: Deduplicate remaining threats
    threats = deduplicate_threats(threats)

    # Step 4: Enrich each threat with calculated metadata
    sensitive_classifications = ("PII", "PHI", "PCI", "Confidential")
    refined_threats = []
    for threat in threats:
        # A threat is tied to a flow when its component name reads "<source> to <destination>".
        flow_details = next((f for f in dfd_flows if f"{f['source']} to {f['destination']}" == threat["component_name"]), None)
        if flow_details:
            classification = flow_details.get("data_classification")
            if classification is None:
                logger.warning(f"--- Data flow '{threat['component_name']}' is missing 'data_classification'. Impact assessment will be generic. ---")
            elif classification in sensitive_classifications:
                # Escalate impact for sensitive data: anything below High becomes
                # High, and High becomes Critical. An impact that is already
                # Critical must stay Critical rather than being reset to High.
                threat["impact"] = "Critical" if threat["impact"] in ("High", "Critical") else "High"

        # Calculate scores and assessments
        threat["risk_score"] = calculate_risk_score(threat["impact"], threat["likelihood"])
        # Residual likelihood drops to "Low" unless the mitigation text mentions
        # logging, in which case the original likelihood is kept.
        mitigated_likelihood = "Low" if "logging" not in threat["mitigation_suggestion"].lower() else threat["likelihood"]
        threat["residual_risk_score"] = calculate_risk_score(threat["impact"], mitigated_likelihood)
        threat["exploitability"] = assess_exploitability(threat, dfd_data)
        threat["mitigation_maturity"] = assess_mitigation_maturity(threat["mitigation_suggestion"])

        # Generate human-readable statements
        threat["justification"] = generate_justification(threat, flow_details)
        threat["risk_statement"] = generate_risk_statement(threat, flow_details, industry)

        refined_threats.append(threat)

    # Step 5: Sort by risk score for prioritization (unknown scores sort last)
    risk_order = {"Critical": 4, "High": 3, "Medium": 2, "Low": 1}
    refined_threats.sort(key=lambda t: risk_order.get(t.get("risk_score"), 0), reverse=True)

    # Step 6: Assemble and validate the final output
    final_output = {
        "threats": refined_threats,
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "source_dfd": os.path.basename(DFD_INPUT_PATH),
            "source_threats": os.path.basename(THREATS_INPUT_PATH),
            "refined_threat_count": len(refined_threats),
            "original_threat_count": original_threat_count,
            "industry_context": industry
        }
    }
    try:
        validated_output = RefinedThreatsOutput(**final_output)
        logger.info("--- Final refined JSON output validated successfully against schema. ---")
    except ValidationError as ve:
        logger.error(f"--- Final JSON validation failed: {ve} ---")
        raise

    # Step 7: Save all outputs
    with open(REFINED_THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(validated_output.model_dump(), f, indent=2)
    logger.info(f"--- Refined {len(refined_threats)} threats saved to '{REFINED_THREATS_OUTPUT_PATH}' ---")

    summary = {
        "total_threats": len(refined_threats),
        "critical_count": sum(1 for t in refined_threats if t["risk_score"] == "Critical"),
        "high_count": sum(1 for t in refined_threats if t["risk_score"] == "High"),
        "medium_count": sum(1 for t in refined_threats if t["risk_score"] == "Medium"),
        "low_count": sum(1 for t in refined_threats if t["risk_score"] == "Low"),
        "prioritization_recommendation": (
            "Remediation should be prioritized based on risk score. Address all 'Critical' and 'High' risk threats within the next development cycle. "
            "Focus on implementing robust, mature controls like mTLS and centralized logging to address systemic weaknesses."
        )
    }
    # NOTE(review): the summary and CSV reports are written under INPUT_DIR rather
    # than OUTPUT_DIR — confirm this is intended before changing the location.
    summary_path = os.path.join(INPUT_DIR, "threat_summary.json")
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    logger.info(f"--- Summary report saved to '{summary_path}' ---")

    # The CSV is built from the in-memory threat dicts, not the validated model dump.
    csv_path = os.path.join(INPUT_DIR, "threats.csv")
    pd.DataFrame(final_output["threats"]).to_csv(csv_path, index=False)
    logger.info(f"--- CSV report saved to '{csv_path}' ---")

# --- Execute ---
if __name__ == "__main__":
    # Script entry point: warm the KEV catalog first, then run the refinement
    # pipeline. Any failure is logged with its traceback instead of propagating.
    try:
        get_cisa_kev_catalog()  # Pre-fetch KEV catalog on startup
        refine_threats()
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the active traceback.
        logger.exception(f"--- Threat refinement process failed with an unrecoverable error: {e} ---")

2025-07-28 20:33:21,238 - INFO - --- Fetching CISA Known Exploited Vulnerabilities (KEV) catalog ---
2025-07-28 20:33:21,763 - INFO - --- Successfully loaded 1391 entries from CISA KEV catalog ---
2025-07-28 20:33:21,764 - INFO - --- Loading initial threats from './output/identified_threats.json' ---
2025-07-28 20:33:21,765 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-28 20:33:21,766 - INFO - --- Loading controls from './output/controls.json' ---
2025-07-28 20:33:21,766 - INFO - --- CVE CVE-2006-6276 is older than 5 years and not in KEV catalog. Considered for suppression. ---
2025-07-28 20:33:21,767 - INFO - --- Removing outdated/irrelevant CVE 'CVE-2006-6276' from threat 'CDN to LB'. ---
2025-07-28 20:33:21,767 - INFO - --- CVE CVE-2006-6276 is older than 5 years and not in KEV catalog. Considered for suppression. ---
2025-07-28 20:33:21,767 - INFO - --- Removing outdated/irrelevant CVE 'CVE-2006-6276' from threat 'CDN to LB'. ---
2025-07-28 20:

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2025-07-28 20:33:24,925 - INFO - --- Merged 2 similar threats for 'U to CDN' (S) ---
2025-07-28 20:33:24,925 - INFO - --- Merged 2 similar threats for 'U to CDN' (D) ---
2025-07-28 20:33:24,926 - INFO - --- Merged 2 similar threats for 'CDN to LB' (I) ---
2025-07-28 20:33:24,926 - INFO - --- Merged 2 similar threats for 'LB to WS' (R) ---
2025-07-28 20:33:24,926 - INFO - --- Merged 2 similar threats for 'LB to WS' (D) ---
2025-07-28 20:33:24,926 - INFO - --- Merged 2 similar threats for 'LB to WS' (E) ---
2025-07-28 20:33:24,927 - INFO - --- Merged 2 similar threats for 'WS to DB_P' (I) ---
2025-07-28 20:33:24,927 - INFO - --- Merged 2 similar threats for 'WS to MQ' (S) ---
2025-07-28 20:33:24,927 - INFO - --- Merged 2 similar threats for 'WS to MQ' (I) ---
2025-07-28 20:33:24,927 - INFO - --- Merged 2 similar threats for 'WS to MQ' (D) ---
2025-07-28 20:33:24,928 - INFO - --- Merged 2 similar threats for 'WS to MQ' (E) ---
2025-07-28 20:33:24,928 - INFO - --- Merged 2 similar threats 