In [None]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install langchain langchain-community langchain-ollama "unstructured[docx]" pillow nltk python-dotenv pydantic
# (`logging` is part of the Python standard library and must not be installed with pip.)

import os
import json
import ssl
import nltk
from urllib.error import URLError
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from datetime import datetime
# --- MODIFICATION: Import partition_docx ---
from unstructured.partition.docx import partition_docx
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError
import logging

# Load environment variables (python-dotenv) before the os.getenv() lookups below.
load_dotenv()

# --- Configuration ---
# Use environment variables for paths and settings
DOCUMENT_PATH = os.getenv("DOCUMENT_PATH", "./docs/designdoc.docx")
IMAGE_OUTPUT_DIR = os.getenv("IMAGE_OUTPUT_DIR", "./output/images")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./output")
NLTK_DATA_PATH = os.path.join(os.path.expanduser("~"), "nltk_data")
# Ollama model name. Configurable like the other settings; the default is the
# model that used to be hard-coded in the ChatOllama() call.
LLM_MODEL = os.getenv("LLM_MODEL", "WhiteRabbitNeo/WhiteRabbitNeo-V3-7B")

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Chat model used by the DFD-extraction chain.
llm = ChatOllama(model=LLM_MODEL)

# --- NLTK Data Management ---
# This function checks for necessary NLTK data and downloads it if missing.
# Avoid SSL unverified context; assume secure environment or pre-downloaded data.
def ensure_nltk_data():
    """
    Make sure the NLTK 'punkt' tokenizer data is available, downloading it if missing.

    Adds NLTK_DATA_PATH to NLTK's search path (if not already present), looks
    the resource up and, when the lookup fails, downloads the package into
    NLTK_DATA_PATH. No SSL-verification workaround is applied.

    Raises:
        RuntimeError: if nltk.download() reports that the download failed.
        Exception: any other error raised while downloading is logged and
            re-raised unchanged.
    """
    required_package = 'punkt'
    # Derive the lookup path from the package name so both stay in sync.
    resource_name = f"tokenizers/{required_package}"

    if NLTK_DATA_PATH not in nltk.data.path:
        nltk.data.path.append(NLTK_DATA_PATH)

    logger.info("--- Verifying NLTK data packages ---")
    try:
        nltk.data.find(resource_name)
        logger.info("[✓] NLTK 'punkt' data is available.")
    except LookupError:
        logger.warning("[!] NLTK 'punkt' data not found. Attempting to download...")
        try:
            # Use default SSL context; no unverified workaround.
            # nltk.download() signals a failed download through a False return
            # value, so the result must be checked before claiming success.
            downloaded = nltk.download(required_package, download_dir=NLTK_DATA_PATH)
            if not downloaded:
                raise RuntimeError(f"nltk.download('{required_package}') reported a failure")
            logger.info(f"[✓] '{required_package}' downloaded successfully.")
        except Exception as e:
            logger.error(f"Failed to download NLTK data: {e}")
            logger.error("Ensure internet access and proper SSL configuration. Alternatively, pre-download NLTK data.")
            raise
    logger.info("--- NLTK setup complete ---")

# --- DFD Components Schema for Validation ---
class DataFlow(BaseModel):
    """One directed data flow between two DFD components, used to validate the LLM output."""
    # Name of the component the data comes from.
    source: str
    # Name of the component the data goes to.
    destination: str
    # Free-text description of what is exchanged.
    data_description: str
    # NOTE(review): the extraction prompt in this file also asks for a
    # 'protocol' key on every flow, but no such field is declared here, so it
    # is not checked by this model — confirm whether it should be added.

class DFDComponents(BaseModel):
    """Schema the parsed LLM response is validated against before it is saved."""
    # Names of data stores.
    assets: list[str]
    # Names of components that act on data.
    processes: list[str]
    # Directed flows between the components above.
    data_flows: list[DataFlow]
    # Added by this script after the LLM call (timestamp and source document).
    metadata: dict
    # NOTE(review): the extraction prompt in this file also requests
    # 'external_entities' and 'trust_boundaries'; neither is declared here, so
    # they are not checked by this model — confirm whether they should be added.

# --- Document & Image Loading with partition_docx ---
# Results of the loading step. full_document_text is later passed to the LLM
# chain as the document text.
full_document_text = ""
elements = []

try:
    # Ensure directories exist
    os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    ensure_nltk_data()

    logger.info(f"--- Loading document and extracting images from: {DOCUMENT_PATH} ---")
    # Use partition_docx to extract both text and images
    # Note: For security, consider scanning the file with an antivirus tool before processing in production.
    # NOTE(review): confirm that the installed `unstructured` version honours
    # these image-extraction keyword arguments for .docx input.
    elements = partition_docx(
        filename=DOCUMENT_PATH,
        extract_images_in_document=True,
        image_output_dir_path=IMAGE_OUTPUT_DIR
    )

    # Reconstruct the document text from elements for the LLM
    # and confirm image extractions.
    text_chunks = []
    for element in elements:
        # The image check is deliberately independent of the text check. It
        # used to be an `elif` behind hasattr(element, 'text') and was
        # therefore skipped for every image element that also has text.
        if "Image" in str(type(element)):
            image_path = getattr(getattr(element, "metadata", None), "image_path", None)
            if image_path:
                logger.info(f"[✓] Extracted image saved to: {image_path}")

        element_text = getattr(element, "text", None)
        if element_text is not None:
            text_chunks.append(element_text)

    # Same layout as before (each element followed by a blank line), built
    # with a single join instead of repeated string concatenation.
    full_document_text = "".join(f"{chunk}\n\n" for chunk in text_chunks)

    if not full_document_text.strip():
        raise ValueError("Document processed but no text content was found.")

    logger.info("--- Successfully processed document ---")

except FileNotFoundError:
    logger.error(f"--- FATAL ERROR: Input document not found at '{DOCUMENT_PATH}' ---")
    # Raise SystemExit directly: the `exit` helper is injected by `site` for
    # interactive use and is not guaranteed to exist or to stop execution.
    raise SystemExit(1)
except Exception as e:
    logger.error("--- FATAL ERROR: An unexpected error occurred while processing the document ---")
    logger.error(f"Error details: {e}")
    raise SystemExit(1)

# --- Prompt Engineering ---
# Instruction text for the DFD-extraction step. {document_text} is the only
# placeholder; it is filled in when the chain is invoked.
# NOTE(review): the prompt asks for 'external_entities', 'trust_boundaries' and
# a per-flow 'protocol', none of which are declared on DFDComponents/DataFlow —
# confirm whether prompt and schema should be aligned.
prompt_template = """
You are an expert cybersecurity architect specializing in threat modeling, with knowledge of current standards as of 2025, including OWASP Threat Modeling Cheat Sheet and NIST SP 800-53.

Your task is to carefully read the provided system design document and decompose it into core components for a Data Flow Diagram (DFD) Level 0 or 1. Use a Chain-of-Thought approach: 
1. First, identify all mentioned components, data stores, and interactions from the document.
2. Classify them systematically: Assets (data at rest, e.g., databases, files), Processes (data transformation, e.g., servers, services), External Entities (e.g., users, third-party APIs).
3. Map data flows between them, describing the data exchanged and any protocols or security controls mentioned.
4. Identify trust boundaries (e.g., zones like public, DMZ, internal) and any external dependencies.

Output a valid JSON object with the following keys:
- 'external_entities': Array of strings (e.g., users or external systems not under control).
- 'assets': Array of strings (data stores where data rests, e.g., database, cache, log file).
- 'processes': Array of strings (components that act on or transform data, e.g., API, microservice, web server).
- 'data_flows': Array of objects, each with 'source' (from external_entities, assets, or processes), 'destination' (same), 'data_description' (what data is exchanged, e.g., "user credentials"), 'protocol' (e.g., "HTTPS", if mentioned).
- 'trust_boundaries': Array of strings describing zones or boundaries (e.g., "Public Zone to Edge Zone").

Stick strictly to the document content—do not hallucinate or add unmentioned elements. If the document includes diagrams or images (via extracted text), incorporate their descriptions.

System Design Document:
---
{document_text}
---

Generate ONLY the JSON object, with no additional text, explanations, or formatting.
"""

# Wrap the raw template so it can be composed into a chain.
prompt = ChatPromptTemplate.from_template(prompt_template)

# --- Chain Construction with JSON Output Parser ---
# Pipe the prompt into the model and parse the model's reply as JSON.
output_parser = JsonOutputParser()
chain = prompt | llm | output_parser

# --- Invocation and Output ---
# Report the model actually configured on `llm`; the message previously named
# a hard-coded model that did not match the configuration.
model_name = getattr(llm, "model", "unknown")
logger.info(f"\n--- Invoking Local LLM Chain ({model_name}) to extract DFD components ---")
output_path = os.path.join(OUTPUT_DIR, "dfd_components.json")

try:
    # Parameterize prompt to mitigate injection (though document_text is trusted here)
    response_dict = chain.invoke({"document_text": full_document_text})

    # The parser can return any JSON value; everything below needs an object.
    if not isinstance(response_dict, dict):
        raise TypeError(
            f"Expected a JSON object from the LLM, got {type(response_dict).__name__}"
        )

    # Add metadata
    response_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_document": DOCUMENT_PATH
    }

    # Validate the output against schema. Only the success/failure matters;
    # the raw dictionary (including keys the schema does not declare) is what
    # gets printed and saved.
    try:
        DFDComponents(**response_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise

    logger.info("\n--- LLM Output (Parsed JSON) ---")
    print(json.dumps(response_dict, indent=2))

    # Save the dictionary to a JSON file
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(response_dict, f, indent=2)

    logger.info(f"\n--- DFD components successfully saved to '{output_path}' ---")

except Exception as e:
    # Best effort: failures are reported but do not abort the session.
    logger.error("\n--- An error occurred during chain invocation or parsing ---")
    logger.error(f"Error: {e}")
    logger.error("This may be due to the LLM not returning a well-formed JSON object.")

2025-07-27 07:03:02,577 - INFO - --- Verifying NLTK data packages ---
2025-07-27 07:03:02,578 - INFO - [✓] NLTK 'punkt' data is available.
2025-07-27 07:03:02,578 - INFO - --- NLTK setup complete ---
2025-07-27 07:03:02,578 - INFO - --- Loading document and extracting images from: ./docs/designdoc.docx ---
2025-07-27 07:03:02,598 - INFO - --- Successfully processed document ---
2025-07-27 07:03:02,599 - INFO - 
--- Invoking Local LLM Chain (Mixtral) to extract DFD components ---
2025-07-27 07:03:10,780 - INFO - HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
2025-07-27 07:03:29,077 - INFO - --- JSON output validated successfully ---
2025-07-27 07:03:29,078 - INFO - 
--- LLM Output (Parsed JSON) ---
2025-07-27 07:03:29,079 - INFO - 
--- DFD components successfully saved to './output/dfd_components.json' ---


{
  "assets": [
    "Primary Database",
    "Backup System"
  ],
  "processes": [
    "User",
    "CDN / WAF",
    "Load Balancer",
    "Web Server Farm",
    "Message Queue Server",
    "Worker Service",
    "Admin Portal"
  ],
  "data_flows": [
    {
      "source": "User",
      "destination": "CDN / WAF",
      "data_description": "HTTPS"
    },
    {
      "source": "CDN / WAF",
      "destination": "Load Balancer",
      "data_description": "HTTPS"
    },
    {
      "source": "Load Balancer",
      "destination": "Web Server Farm",
      "data_description": "HTTPS"
    },
    {
      "source": "Web Server Farm",
      "destination": "Primary Database",
      "data_description": "JDBC/ODBC over TLS"
    },
    {
      "source": "Web Server Farm",
      "destination": "Message Queue Server",
      "data_description": "AMQP over TLS"
    },
    {
      "source": "Worker Service",
      "destination": "Message Queue Server",
      "data_description": "AMQP over TLS"
    },
    {
   

In [None]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install langchain langchain-community langchain-ollama python-dotenv pydantic
# (`logging` is part of the Python standard library and must not be installed with pip.)

import os
import json
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError
import logging

# Load environment variables (python-dotenv) before any setting is read.
load_dotenv()

# --- Configuration ---
# Each setting below can be overridden through an environment variable.
LLM_MODEL = os.getenv("LLM_MODEL", "mixtral")
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv(
    "DFD_INPUT_PATH",
    os.path.join(INPUT_DIR, "dfd_components.json"),
)
THREATS_OUTPUT_PATH = os.getenv(
    "THREATS_OUTPUT_PATH",
    os.path.join(INPUT_DIR, "identified_threats.json"),
)

# Logging: "<time> - <level> - <message>", INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Chat model used by the threat-generation chain.
llm = ChatOllama(model=LLM_MODEL)

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    """A single identified threat, used to validate each entry of the LLM output."""
    # Affected asset, process, data flow, or entity.
    component_name: str
    # STRIDE category assigned to the threat.
    stride_category: str
    # Description of the threat.
    threat_description: str
    # Suggested mitigation.
    mitigation_suggestion: str
    # NOTE(review): the threat prompt in this file also asks for 'impact',
    # 'likelihood' and 'references'; none of them is declared here, so they
    # are not checked by this model — confirm whether they should be added.

class ThreatsOutput(BaseModel):
    """Schema the parsed threat-generation response is validated against."""
    # All threats returned by the LLM under the 'threats' key.
    threats: list[Threat]
    # Added by this script after the LLM call (timestamp and source DFD path).
    metadata: dict

# --- Load DFD Components ---
# Read the DFD components written by the extraction step. Any failure here is
# fatal because the threat prompt is built from this data.
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    # Explicit encoding so the result does not depend on the platform default.
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
    logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.error(f"--- FATAL ERROR: Input file not found at '{DFD_INPUT_PATH}' ---")
    logger.error("Please run the first script (to generate DFD components) before running this one.")
    # Raise SystemExit directly: the `exit` helper is injected by `site` for
    # interactive use and is not guaranteed to exist or to stop execution.
    raise SystemExit(1)
except json.JSONDecodeError:
    logger.error(f"--- FATAL ERROR: Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty.")
    raise SystemExit(1)
except Exception as e:
    logger.error("--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    raise SystemExit(1)

# --- Prompt Engineering for Threat Generation ---
# Instruction text for the STRIDE threat-generation step. {dfd_json} is the
# only placeholder; it is filled in when the chain is invoked.
# NOTE(review): the prompt asks for 'impact', 'likelihood' and 'references' on
# each threat, none of which are declared on the Threat model — confirm whether
# prompt and schema should be aligned.
threat_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where applicable.
2. Describe threats considering trust boundaries, protocols, and potential attack vectors (e.g., injection, misconfiguration).
3. Suggest mitigations with references to standards (e.g., "NIST AC-6 for least privilege").
4. Assess impact (Low/Medium/High based on potential damage) and likelihood (Low/Medium/High based on exploitability).

For each threat, include:
- 'component_name': Affected asset, process, data flow, or entity.
- 'stride_category': One STRIDE category.
- 'threat_description': Clear, specific description (e.g., "Attacker intercepts unencrypted data in transit leading to disclosure").
- 'mitigation_suggestion': Practical, actionable mitigation (e.g., "Implement TLS 1.3 with certificate pinning").
- 'impact': Low/Medium/High.
- 'likelihood': Low/Medium/High.
- 'references': Array of strings (e.g., ["OWASP A01:2021", "NIST SI-2"]).

DFD Components:
---
{dfd_json}
---

Generate a JSON object with a key 'threats' (array of threat objects). Output ONLY the JSON, with no additional commentary or formatting.
"""

# Wrap the raw template so it can be composed into a chain.
threat_prompt = ChatPromptTemplate.from_template(threat_prompt_template)

# --- Chain Construction with JSON Output Parser ---
# Pipe the prompt into the model and parse the model's reply as JSON.
threat_parser = JsonOutputParser()
threat_chain = threat_prompt | llm | threat_parser

# --- Invocation and Output ---
# Name the configured model (LLM_MODEL) rather than a hard-coded one.
logger.info(f"\n--- Invoking Local LLM Chain ({LLM_MODEL}) to generate STRIDE threats ---")
try:
    # Convert the loaded DFD dictionary back to a JSON string for the prompt
    dfd_json_string = json.dumps(dfd_data, indent=2)

    # Invoke the threat analysis chain
    threats_dict = threat_chain.invoke({"dfd_json": dfd_json_string})

    # The parser can return any JSON value; everything below needs an object.
    if not isinstance(threats_dict, dict):
        raise TypeError(
            f"Expected a JSON object from the LLM, got {type(threats_dict).__name__}"
        )

    # Add metadata
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }

    # Validate the output against schema. Only the success/failure matters;
    # the raw dictionary is what gets saved and printed.
    try:
        ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise

    # Ensure the directory that will actually receive the file exists.
    # THREATS_OUTPUT_PATH can be overridden independently of INPUT_DIR, so
    # creating INPUT_DIR alone is not sufficient.
    threats_output_dir = os.path.dirname(THREATS_OUTPUT_PATH)
    if threats_output_dir:
        os.makedirs(threats_output_dir, exist_ok=True)

    # Save the threats to a new file
    with open(THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    # Best effort: failures are reported but do not abort the session.
    logger.error("\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input data.")

2025-07-27 07:03:29,127 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 07:03:29,127 - INFO - --- DFD components loaded successfully ---
2025-07-27 07:03:29,128 - INFO - 
--- Invoking Local LLM Chain (Mixtral) to generate STRIDE threats ---
2025-07-27 07:03:45,117 - INFO - HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
2025-07-27 07:04:09,795 - INFO - --- JSON output validated successfully ---
2025-07-27 07:04:09,799 - INFO - 
--- LLM Output (Identified Threats) ---
2025-07-27 07:04:09,800 - INFO - 
--- Identified threats successfully saved to './output/identified_threats.json' ---


{
  "threats": [
    {
      "component_name": "User",
      "stride_category": "Spoofing",
      "threat_description": "An attacker may impersonate a legitimate user to gain unauthorized access.",
      "mitigation_suggestion": "Implement multi-factor authentication and enforce strong password policies."
    },
    {
      "component_name": "CDN / WAF",
      "stride_category": "Information Disclosure",
      "threat_description": "Sensitive data may be exposed due to misconfigured or weak security settings in the CDN/WAF.",
      "mitigation_suggestion": "Regularly audit and configure CDN/WAF settings according to best practices."
    },
    {
      "component_name": "Load Balancer",
      "stride_category": "Denial of Service",
      "threat_description": "A DDoS attack can overload the load balancer, making services unavailable.",
      "mitigation_suggestion": "Implement a robust DDoS mitigation strategy, including traffic filtering and rate limiting."
    },
    {
      "componen