In [17]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install langchain langchain-community langchain-ollama python-dotenv pydantic instructor ollama
# (`logging` is part of the Python standard library and needs no install.)

import os
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import instructor
from ollama import Client  # Added for raw response debugging

# Merge any settings from a local .env file into os.environ before they are read.
load_dotenv()

# --- Configuration ---
# Each setting can be overridden by an environment variable of the same name;
# the second argument of os.getenv is the fallback default.
LLM_MODEL = os.getenv("LLM_MODEL", "llama3:8b")
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv(
    "DFD_INPUT_PATH",
    os.path.join(INPUT_DIR, "dfd_components.json"),
)
THREATS_OUTPUT_PATH = os.getenv(
    "THREATS_OUTPUT_PATH",
    os.path.join(INPUT_DIR, "identified_threats.json"),
)

# Logging: timestamped, level-tagged messages at INFO and above.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

# INPUT_DIR also holds the output file, so create it up front.
os.makedirs(INPUT_DIR, exist_ok=True)

# Instructor-wrapped client: enforces the pydantic response schema via JSON schema mode.
_provider_id = f"ollama/{LLM_MODEL}"
llm = instructor.from_provider(_provider_id, mode=instructor.Mode.JSON_SCHEMA)

# Plain Ollama client, used only to log the model's unstructured output.
ollama_client = Client()

# --- Threat Schema for Validation ---
# NOTE: pydantic uses a model's docstring as the JSON-schema "description", and
# that schema is what Instructor hands to the LLM. Documentation for these
# models is therefore kept in `#` comments so the prompt schema is unchanged.
class Threat(BaseModel):
    # One STRIDE threat as returned by the LLM. Every rating field is a plain
    # `str`: the allowed values (STRIDE names, Low/Medium/High) are requested
    # only through the Field descriptions and are NOT enforced by validation.
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One STRIDE category: Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege.")
    threat_description: str = Field(description="Clear, specific description of the threat.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low/Medium/High based on potential damage.")
    likelihood: str = Field(description="Low/Medium/High based on exploitability.")
    references: list[str] = Field(description="Array of standard references (e.g., ['OWASP A01:2021', 'NIST SI-2']).")

class Threats(BaseModel):
    # Response model for the Instructor call (`response_model=Threats`): the LLM
    # must return an object with a single "threats" array.
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    # Shape of the saved output: the threats plus a free-form metadata dict
    # (the invocation code fills in "timestamp" and "source_dfd").
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
# Fallback DFD used when the input file is missing, empty or unparseable.
# Data flows use the same source/destination/data_description/protocol keys as
# the sample DFDs in the later cells, so the fallback matches the DFD schema
# the rest of the pipeline works with.
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "source": "User",
            "destination": "Web Application",
            "data_description": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "source": "Web Application",
            "destination": "User Database",
            "data_description": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
# A missing, empty or unparseable DFD file falls back to SAMPLE_DFD so the
# notebook can still be exercised; any other error aborts the run.
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
    if not dfd_data:
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    else:
        logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError as e:
    # Recoverable: processing continues with the sample DFD, so this is a
    # warning, not a fatal error.
    logger.warning(f"--- Could not parse JSON from '{DFD_INPUT_PATH}': {e} ---")
    logger.warning("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error("--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    # `exit()` is a site/IPython convenience helper; inside a Jupyter kernel it
    # does not reliably stop the running cell, which would leave `dfd_data`
    # undefined for the code below. Raising SystemExit always stops execution.
    raise SystemExit(1) from e

# --- Prompt Engineering for Threat Generation ---
# Prompt for the single generation call. It is rendered through
# ChatPromptTemplate.from_template (f-string style), so `{dfd_json}` is the
# only template variable; any literal brace added later must be doubled.
# NOTE(review): the prompt asks for "Chain-of-Thought reasoning" but also
# demands "ONLY the JSON"; these instructions pull in opposite directions and
# may be why the model's raw output needs schema enforcement.
threat_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where applicable.
2. Describe threats considering trust boundaries, protocols, and potential attack vectors (e.g., injection, misconfiguration).
3. Suggest mitigations with references to standards (e.g., "NIST AC-6 for least privilege").
4. Assess impact (Low/Medium/High based on potential damage) and likelihood (Low/Medium/High based on exploitability).

For each threat, include:
- 'component_name': Affected asset, process, data flow, or entity.
- 'stride_category': One STRIDE category.
- 'threat_description': Clear, specific description (e.g., "Attacker intercepts unencrypted data in transit leading to disclosure").
- 'mitigation_suggestion': Practical, actionable mitigation (e.g., "Implement TLS 1.3 with certificate pinning").
- 'impact': Low/Medium/High.
- 'likelihood': Low/Medium/High.
- 'references': Array of strings (e.g., ["OWASP A01:2021", "NIST SI-2"]).

DFD Components:
---
{dfd_json}
---

Generate a JSON object with a key 'threats' (array of threat objects). Output ONLY the JSON, with no additional commentary or formatting.
"""

# Chat prompt whose format_messages() yields the message list sent to the LLM.
threat_prompt = ChatPromptTemplate.from_template(threat_prompt_template)

# --- Invocation and Output ---
# Render the prompt, log a raw (unstructured) model reply for debugging, then
# obtain a schema-enforced `Threats` object via Instructor, attach metadata,
# validate, and write the result to THREATS_OUTPUT_PATH.
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Serialize the loaded DFD back to JSON text for the prompt.
    dfd_json_string = json.dumps(dfd_data, indent=2)

    # The template renders to one human message; its text feeds both calls below.
    messages = threat_prompt.format_messages(dfd_json=dfd_json_string)
    prompt_text = messages[0].content

    logger.info(f"--- Prompt sent to LLM ---\n{prompt_text}")

    # Debug-only raw call so the model's unstructured output can be inspected.
    # NOTE(review): this repeats the whole generation, roughly doubling run time.
    raw_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": prompt_text}])
    logger.info(f"--- Raw LLM Response ---\n{raw_response['message']['content']}")

    # Structured call: Instructor validates the reply against `Threats` and
    # re-asks the model on validation failure, up to max_retries times.
    threats_obj = llm.chat.completions.create(
        messages=[{"role": "user", "content": prompt_text}],
        response_model=Threats,
        max_retries=5,
    )

    threats_dict = threats_obj.model_dump()

    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }

    # Re-validate including the metadata (Instructor already checked the threats).
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise

    # Explicit UTF-8 so non-ASCII LLM text is written identically on every platform.
    with open(THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    # Notebook boundary: report (with traceback) instead of crashing the cell.
    logger.exception("\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input data.")

2025-07-27 19:49:51,081 - INFO - Initializing ollama provider with model llama3:8b
2025-07-27 19:49:51,103 - INFO - Client initialized
2025-07-27 19:49:51,118 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 19:49:51,119 - INFO - --- DFD components loaded successfully ---
2025-07-27 19:49:51,119 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-27 19:49:51,120 - INFO - --- Prompt sent to LLM ---

You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where 

{
  "threats": [
    {
      "component_name": "U",
      "stride_category": "Information Disclosure",
      "threat_description": "Unencrypted data transmitted from U to CDN potentially disclosed",
      "mitigation_suggestion": "Implement end-to-end encryption; Use HTTPS protocol",
      "impact": "Medium",
      "likelihood": "High",
      "references": [
        "OWASP A01:2021",
        "NIST SI-2"
      ]
    },
    {
      "component_name": "CDN",
      "stride_category": "Tampering",
      "threat_description": "Malicious actor intercepts and modifies data in transit from CDN to LB",
      "mitigation_suggestion": "Implement integrity checking; Use digital signatures",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A03:2021",
        "MITRE CA-8"
      ]
    },
    {
      "component_name": "LB",
      "stride_category": "Elevation of Privilege",
      "threat_description": "Unprivileged actor gains elevated privileges on LB, potentia

In [47]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install instructor openai ollama pydantic python-dotenv
# (`logging` is part of the Python standard library and needs no install.)

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from ollama import Client  # For raw debugging
import instructor
from openai import OpenAI  # Wrapper for Ollama

# Load environment variables (from a local .env file, if present)
load_dotenv()

# --- Configuration ---
LLM_MODEL = os.getenv("LLM_MODEL", "llama3-70b-m3max:latest")
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))
# Root URL of the Ollama server, without the "/v1" suffix. This was previously
# hard-coded twice. A dedicated variable is used instead of OLLAMA_HOST because
# OLLAMA_HOST is commonly set to a bind address (e.g. "0.0.0.0") for the server.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://aioverlord:11434").rstrip("/")

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(INPUT_DIR, exist_ok=True)

# Instructor client over Ollama's OpenAI-compatible endpoint (<base>/v1).
try:
    client = instructor.from_openai(
        OpenAI(
            base_url=f"{OLLAMA_BASE_URL}/v1",
            api_key="ollama",  # Required by the OpenAI client but unused by Ollama
        ),
        mode=instructor.Mode.JSON,  # Use JSON mode for Ollama
    )
    logger.info(f"--- Ollama client initialized successfully at {OLLAMA_BASE_URL} ---")
except Exception as e:
    logger.error(f"--- Failed to initialize Ollama client at {OLLAMA_BASE_URL}: {e} ---")
    raise

# Native Ollama client, used for raw (unstructured) debugging calls
try:
    ollama_client = Client(host=OLLAMA_BASE_URL)
    logger.info(f"--- Raw Ollama client initialized successfully at {OLLAMA_BASE_URL} ---")
except Exception as e:
    logger.error(f"--- Failed to initialize raw Ollama client at {OLLAMA_BASE_URL}: {e} ---")
    raise

# Health check: listing models is a cheap request that proves the server is reachable.
try:
    models_response = ollama_client.list()
    logger.info(f"--- Ollama server health check successful at {OLLAMA_BASE_URL}. Available models: {models_response.get('models', 'None listed')} ---")
except Exception as e:
    logger.error(f"--- Ollama server health check failed at {OLLAMA_BASE_URL}: {e} ---")
    raise

# --- Threat Schema for Validation ---
# NOTE: pydantic uses a model's docstring as the JSON-schema "description", and
# Instructor sends that schema to the LLM, so these models are documented with
# `#` comments to keep the prompt schema unchanged.
class Threat(BaseModel):
    # One STRIDE threat. All fields are plain `str`/`list[str]`; the allowed
    # values (single-letter STRIDE codes, Low/Medium/High, the risk_score
    # matrix) are only requested via descriptions and are NOT enforced here.
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One letter: S, T, R, I, D, or E.")  # Aligned with prompt
    threat_description: str = Field(description="Clear, specific description of the threat, including attack vectors and consequences.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low, Medium, or High based on potential damage.")
    likelihood: str = Field(description="Low, Medium, or High based on exploitability.")
    references: list[str] = Field(description="Array of 1-3 valid references (e.g., ['OWASP A01:2021', 'NIST SP 800-53 SC-28', 'CWE-89']).")
    risk_score: str = Field(description="Critical, High, Medium, or Low (calculated from impact and likelihood).")

class Threats(BaseModel):
    # Response model for the Instructor generation call: {"threats": [...]}.
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    # Final saved shape: threats plus a free-form metadata dict
    # (the invocation code fills in "timestamp" and "source_dfd").
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
# Small two-flow system substituted when the DFD file is missing, empty or
# unparseable. Built with keyword-style dict() calls; key order is preserved,
# so json.dumps renders it exactly as a literal would.
SAMPLE_DFD = dict(
    external_entities=["User", "Attacker"],
    processes=["Web Application", "Authentication Service"],
    data_stores=["User Database"],
    data_flows=[
        dict(
            source="User",
            destination="Web Application",
            data_description="Login Credentials",
            protocol="HTTP",
        ),
        dict(
            source="Web Application",
            destination="User Database",
            data_description="Query User Data",
            protocol="SQL",
        ),
    ],
    trust_boundaries=["Internet to DMZ", "DMZ to Internal Network"],
)

# --- Load DFD Components ---
# A missing, empty or unparseable DFD file falls back to SAMPLE_DFD; any other
# error aborts the run.
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r', encoding='utf-8') as f:
        dfd_data = json.load(f)
    if not dfd_data:
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    else:
        logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError as e:
    # Recoverable: processing continues with the sample DFD, so this is a
    # warning, not a fatal error.
    logger.warning(f"--- Could not parse JSON from '{DFD_INPUT_PATH}': {e} ---")
    logger.warning("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error("--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    # `exit()` is a site/IPython helper that does not reliably stop a running
    # Jupyter cell (leaving `dfd_data` undefined); SystemExit always does.
    raise SystemExit(1) from e

# --- Improved Prompt Engineering for Threat Generation ---
# Filled with str.format(dfd_json=...): every literal JSON brace is doubled
# ({{ }}) so that `{dfd_json}` is the only substitution.
# The few-shot examples must cite the correct OWASP Top 10 2021 categories,
# because the model copies them:
#   - credential spoofing -> A07:2021 Identification and Authentication Failures
#   - SQL injection       -> A03:2021 Injection
# (They previously cited A05 Misconfiguration and A04 Insecure Design.)
threat_prompt_template = """
You are a cybersecurity architect specializing in threat modeling using the STRIDE framework. Your task is to generate a complete list of relevant threats based on the provided DFD components in JSON format.

STRIDE categories must be exactly one of: 
- S: Spoofing
- T: Tampering
- R: Repudiation
- I: Information Disclosure
- D: Denial of Service
- E: Elevation of Privilege

Do not use any other categories or variations. Generate threats covering ALL components, including:
- External entities (e.g., "User")
- Processes (e.g., "Web Application")
- Assets (data stores, e.g., "User Database")
- Data flows (label as "Source to Destination", e.g., "User to Web Application")
- Trust boundaries (e.g., "Internet to DMZ")

For each component, explicitly consider all six STRIDE categories (S, T, R, I, D, E) and generate threats for each where realistically applicable, aiming for at least 5-6 threats per component to ensure broad and comprehensive coverage. Justify omissions internally if a category truly does not apply, but prioritize inclusion to avoid gaps.

For each threat, output strictly in this JSON structure:
{{
  "component_name": "Affected component (e.g., 'User to Web Application' or 'User Database')",
  "stride_category": "One letter: S, T, R, I, D, or E",
  "threat_description": "Clear, specific description tied to the component, including attack vectors (e.g., 'via MITM on weak TLS leading to PII exposure and regulatory fines'). Reference protocols from the DFD (e.g., HTTPS, AMQP).",
  "mitigation_suggestion": "Practical, actionable mitigation (e.g., 'Implement HTTPS with certificate pinning and HSTS').",
  "impact": "Low, Medium, or High",
  "likelihood": "Low, Medium, or High",
  "references": ["Array of 1-3 valid strings (e.g., 'OWASP A01:2021', 'NIST SP 800-53 SC-28', 'CWE-89')"],
  "risk_score": "Critical, High, Medium, or Low (calculate as: Critical if Impact=High and Likelihood=Medium/High; High if Impact=High and Likelihood=Low or Impact=Medium and Likelihood=High; Medium if Impact=Medium and Likelihood=Medium/Low or Impact=Low and Likelihood=High; Low otherwise)"
}}

Think step-by-step internally before generating:
1. Parse the DFD JSON: Explicitly list all external_entities, processes, assets (data_stores), data_flows (as 'source to destination'), and trust_boundaries. Ensure every item is addressed.
2. For each listed component, brainstorm applicable STRIDE threats: Consider all six categories, identifying typical risks (e.g., data flows vulnerable to T and I due to transit; processes to D and E from overload or vulns; data stores to I at rest). Generate at least one threat per category unless impossible, to achieve 5-6 per component.
3. Make threats realistic and specific: Include attack vectors (e.g., SQL injection for tampering), consequences (e.g., data breach leading to fines), and reference DFD protocols/architecture.
4. Avoid duplicates: Ensure no identical threats across components; vary descriptions even for similar risks.
5. Assign impact/likelihood with justification: Base on exposure (e.g., public-facing = higher likelihood) and calculate risk_score accurately.
6. Use only valid references from: OWASP Top 10 2021 (A01-A10 only, e.g., A03:2021 for Injection), NIST SP 800-series (e.g., 800-53, 800-63B, 800-52), CWE (e.g., CWE-89), MITRE ATT&CK (e.g., T1071). Limit to 1-3 per threat; do not invent invalid ones like A11.

Examples of good threats:
{{
  "component_name": "User to Web Application",
  "stride_category": "S",
  "threat_description": "Attacker spoofs user identity via phishing to send fake login credentials over HTTP, leading to account takeover and unauthorized access.",
  "mitigation_suggestion": "Implement strong authentication such as OAuth 2.0 with JWT tokens and MFA.",
  "impact": "High",
  "likelihood": "Medium",
  "references": ["OWASP A07:2021", "NIST SP 800-63B", "CWE-287"],
  "risk_score": "Critical"
}},
{{
  "component_name": "User Database",
  "stride_category": "I",
  "threat_description": "Unauthorized access via SQL injection leading to data leakage of sensitive PII, resulting in privacy violations and fines.",
  "mitigation_suggestion": "Encrypt data at rest using AES-256, implement parameterized queries, and enforce least-privilege access controls.",
  "impact": "High",
  "likelihood": "Medium",
  "references": ["OWASP A03:2021", "NIST SP 800-53 SC-28", "CWE-200"],
  "risk_score": "Critical"
}}

Negative examples to avoid:
- Invalid reference like "OWASP A11:2021" (Top 10 only has A01-A10).
- Generic description like "Unauthorized access to database leading to data leakage" (add vectors/consequences).
- Missing STRIDE coverage without internal justification.
- Fewer than 5 threats per component, leading to gaps.

DFD Components JSON:
{dfd_json}

Output ONLY a JSON object with:
- "threats": [array of threat objects, sorted by risk_score descending (Critical first, then High, Medium, Low)]

Do not include metadata or any other keys. Output ONLY the JSON, with no additional text, commentary, reasoning, or formatting.

"""

# --- Validation Prompt Template ---
# Second-pass prompt asking the LLM to check and correct the generated threats.
# It is filled with str.format(threats_json=...); the template contains no other
# braces, so no escaping is needed. The reply is parsed as plain JSON (no schema
# enforcement), so callers must be prepared for it to be malformed.
validation_prompt_template = """
You are a JSON validator for threat modeling outputs. Your task is to validate and correct the following threats JSON to ensure:
- It is valid JSON.
- Each threat matches the required schema: component_name, stride_category (S/T/R/I/D/E), threat_description, mitigation_suggestion, impact (Low/Medium/High), likelihood (Low/Medium/High), references (list of 1-3 strings), risk_score (Critical/High/Medium/Low).
- No duplicates.
- Risk scores are correctly calculated based on impact and likelihood.
- References are valid (from OWASP, NIST, CWE, MITRE).
- Threats are sorted by risk_score descending (Critical > High > Medium > Low).

If invalid or improvable, correct it. If valid, return the original.

Input Threats JSON:
{threats_json}

Output ONLY the corrected JSON object with key "threats": [array of threat objects]. No other text or explanations.
"""

# --- Invocation and Output ---


def _parse_llm_json(text):
    """Parse a JSON value from raw LLM reply text.

    Local models often wrap JSON in a Markdown code fence (```json ... ```)
    even when told not to; one surrounding fence is stripped before parsing.

    Args:
        text: The raw message content returned by the model.

    Returns:
        The decoded JSON value (dict, list, str, number, ...).

    Raises:
        json.JSONDecodeError: if the (unfenced) text is not valid JSON.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (e.g. "```json") and a closing fence.
        cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
        cleaned = cleaned.rstrip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return json.loads(cleaned)


logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Prepare the generation prompt with DFD JSON (template braces are escaped).
    dfd_json_string = json.dumps(dfd_data, indent=2)
    gen_prompt = threat_prompt_template.format(dfd_json=dfd_json_string)

    logger.info(f"--- Generation Prompt sent to LLM ---\n{gen_prompt}")

    # Debug-only raw call so the unstructured output can be inspected.
    # NOTE(review): this repeats the full generation, roughly doubling run time.
    raw_gen_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": gen_prompt}])
    logger.info(f"--- Raw LLM Generation Response ---\n{raw_gen_response['message']['content']}")

    # Invoke with Instructor for structured, schema-validated output
    threats_obj = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": gen_prompt}],
        response_model=Threats,
        max_retries=5,  # Retries for validation failures
    )
    logger.info("--- Structured threat generation successful ---")

    # Already validated against the Threat schema by Instructor, so this list is
    # the safe fallback if the LLM "validator" pass below produces bad data.
    generated_threats = threats_obj.model_dump()["threats"]
    threats_dict = {"threats": generated_threats}

    # Second pass: ask the LLM to review/correct the threats.
    threats_json_string = json.dumps(threats_dict, indent=2)
    val_prompt = validation_prompt_template.format(threats_json=threats_json_string)

    logger.info(f"--- Validation Prompt sent to LLM ---\n{val_prompt}")

    raw_val_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": val_prompt}])
    logger.info(f"--- Raw LLM Validation Response ---\n{raw_val_response['message']['content']}")

    # The validator reply is unconstrained text; accept an object or a bare
    # array, and keep the original threats for anything else (previously a
    # non-object reply raised AttributeError on .get() and aborted the run).
    try:
        validated_threats = _parse_llm_json(raw_val_response['message']['content'])
    except json.JSONDecodeError:
        logger.warning("--- Validation parsing failed; using original threats ---")
    else:
        if isinstance(validated_threats, dict):
            threats_dict = {"threats": validated_threats.get("threats", generated_threats)}
            logger.info("--- Threats validated and parsed successfully ---")
        elif isinstance(validated_threats, list):
            threats_dict = {"threats": validated_threats}
            logger.info("--- Threats validated and parsed successfully ---")
        else:
            logger.warning("--- Validation response is not a JSON object or array; using original threats ---")

    metadata = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH,
    }
    threats_dict["metadata"] = metadata

    # Validate against the full output schema.
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully against schema ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        # The validator pass broke the schema: fall back to the Instructor
        # result, which is known to be schema-valid (the raw generation text
        # used previously was never validated and could fail too).
        threats_dict = {"threats": generated_threats, "metadata": metadata}
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- Fell back to the schema-validated generation output ---")

    # Persist exactly what passed validation (drops any extra keys the LLM added).
    threats_dict = validated.model_dump()

    with open(THREATS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    # Notebook boundary: report with traceback instead of crashing the cell.
    logger.exception("\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to Ollama not enforcing JSON strictly or connection issues with the Ollama server (default port 11434). Check raw response and adjust prompt.")

2025-07-28 13:17:02,805 - INFO - --- Ollama client initialized successfully on port 11434 ---
2025-07-28 13:17:02,821 - INFO - --- Raw Ollama client initialized successfully on port 11434 ---
2025-07-28 13:17:02,861 - INFO - HTTP Request: GET http://aioverlord:11434/api/tags "HTTP/1.1 200 OK"
2025-07-28 13:17:02,861 - INFO - --- Ollama server health check successful on port 11434. Available models: [Model(model='llama3-70b-m3max:latest', modified_at=datetime.datetime(2025, 7, 28, 11, 7, 1, 64630, tzinfo=TzInfo(+02:00)), digest='11a9abedc8a182544453e5f23f6f425ed0d551200cdaba1aef62e626982d5c05', size=39969745456, details=ModelDetails(parent_model='', format='gguf', family='llama', families=['llama'], parameter_size='70.6B', quantization_level='Q4_0')), Model(model='llama3:70b-instruct', modified_at=datetime.datetime(2025, 7, 28, 10, 7, 48, 277784, tzinfo=TzInfo(+02:00)), digest='786f3184aec0e907952488b865362bdaa38180739a9881a8190d85bad8cab893', size=39969745349, details=ModelDetails(pare

KeyboardInterrupt: 

In [None]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install openai pydantic python-dotenv langchain langchain_community langchain_huggingface faiss-cpu pypdf sentence-transformers
# (`logging` is part of the Python standard library and needs no install.)

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from openai import OpenAI

# RAG specific imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load environment variables (e.g. SCW_SECRET_KEY) from a local .env file, if present.
load_dotenv()

# --- Configuration ---
LLM_MODEL = os.getenv("LLM_MODEL", "llama-3.1-70b-instruct") # Updated model suggestion
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# RAG Configuration
# Overridable from the environment like every other setting above; the
# defaults are the previously hard-coded relative paths.
RAG_DOCS_DIR = os.getenv("RAG_DOCS_DIR", "rag_docs")          # source documents (.pdf/.md/.txt)
FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "faiss_index")  # persisted vector index

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure directories exist
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(RAG_DOCS_DIR, exist_ok=True)


def setup_rag_pipeline(docs_dir=RAG_DOCS_DIR, index_path=FAISS_INDEX_PATH):
    """Create or load the FAISS vector store used for retrieval.

    If a saved index exists in ``index_path`` it is loaded. Otherwise every
    ``.pdf``, ``.md`` and ``.txt`` file under ``docs_dir`` (recursively) is
    loaded, split into overlapping chunks, embedded, and indexed. The new
    index is then saved to ``index_path``.

    Args:
        docs_dir: Directory holding the source documents (defaults to RAG_DOCS_DIR).
        index_path: Directory of the persisted FAISS index (defaults to FAISS_INDEX_PATH).

    Returns:
        The FAISS vector store.

    Raises:
        ValueError: if no index exists and no supported documents are found.
    """
    logger.info("--- Setting up RAG pipeline ---")

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    # save_local() writes "index.faiss" (plus "index.pkl") into the directory.
    # Check for that file, not just the directory, so an empty or partially
    # written directory triggers a rebuild instead of a load error.
    if os.path.isfile(os.path.join(index_path, "index.faiss")):
        logger.info(f"--- Loading existing FAISS index from '{index_path}' ---")
        # SECURITY: load_local unpickles index.pkl. allow_dangerous_deserialization
        # is acceptable only because the index is produced locally by this
        # function; never point index_path at an untrusted location.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)

    logger.info("--- No existing FAISS index found. Building a new one. ---")

    # glob pattern -> (loader class, loader kwargs). Text files are decoded as
    # UTF-8 explicitly rather than with the platform's default encoding.
    loaders = {
        "**/*.pdf": (PyPDFLoader, {}),
        "**/*.md": (TextLoader, {"encoding": "utf-8"}),
        "**/*.txt": (TextLoader, {"encoding": "utf-8"}),
    }
    documents = []
    for glob, (loader_cls, loader_kwargs) in loaders.items():
        try:
            loader = DirectoryLoader(
                docs_dir,
                glob=glob,
                loader_cls=loader_cls,
                loader_kwargs=loader_kwargs,
                show_progress=True,
                use_multithreading=True,
                silent_errors=True,  # skip unreadable files instead of aborting
            )
            documents.extend(loader.load())
        except Exception as e:
            # Best-effort: one failing pattern should not prevent indexing the others.
            logger.warning(f"Could not load files with pattern {glob} using {loader_cls.__name__}. Error: {e}")

    if not documents:
        logger.error(f"--- FATAL: No supported documents (e.g., .pdf, .md, .txt) were found in '{docs_dir}'. ---")
        logger.error("--- Please add your security documents (e.g., OWASP PDFs) to this directory. ---")
        raise ValueError(f"No supported documents found in '{docs_dir}'.")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(documents)

    logger.info(f"--- Creating FAISS index from {len(docs)} document chunks. This may take a moment... ---")
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(index_path)
    logger.info(f"--- FAISS index created and saved to '{index_path}' ---")

    return db

# --- Initialize RAG and OpenAI Client ---
# The LLM is reached through Scaleway's OpenAI-compatible endpoint. The base
# URL may be overridden with SCW_BASE_URL (default: the previously hard-coded
# project URL).
try:
    # Check the key before the (slow) RAG build so a missing key fails fast.
    # An unset key would otherwise make the OpenAI client fall back to
    # OPENAI_API_KEY or fail with a generic message.
    scw_api_key = os.getenv("SCW_SECRET_KEY")
    if not scw_api_key:
        raise ValueError("SCW_SECRET_KEY is not set; add it to the environment or the .env file.")

    rag_db = setup_rag_pipeline()
    client = OpenAI(
        base_url=os.getenv(
            "SCW_BASE_URL",
            "https://api.scaleway.ai/4a8fd76b-8606-46e6-afe6-617ce8eeb948/v1",
        ),
        api_key=scw_api_key,
    )
    logger.info("--- OpenAI client initialized successfully ---")
except Exception as e:
    logger.error(f"--- Failed to initialize services: {e} ---")
    raise

# --- Threat Schema for Validation ---
# Documented with `#` comments rather than docstrings: pydantic uses a model's
# docstring as its JSON-schema description, and the schema may be sent to the LLM.
class Threat(BaseModel):
    # One STRIDE threat. All values are unconstrained strings; allowed values
    # (STRIDE codes, Low/Medium/High, risk levels) are not enforced here.
    component_name: str
    stride_category: str
    threat_description: str
    mitigation_suggestion: str
    impact: str
    likelihood: str
    references: list[str]
    risk_score: str

class ThreatsOutput(BaseModel):
    # Saved output shape: the threats plus a free-form metadata dict.
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing ---
# Minimal demonstration data-flow diagram, substituted when the real DFD file
# is missing or empty. Every section maps to a list of dicts.
SAMPLE_DFD = dict(
    external_entities=[{"name": "User"}],
    processes=[
        {"name": "Web Application"},
        {"name": "Authentication Service"},
    ],
    data_stores=[{"name": "User Database"}],
    data_flows=[
        {
            "source": "User",
            "destination": "Web Application",
            "data_description": "Login Credentials",
            "protocol": "HTTPS",
        },
        {
            "source": "Web Application",
            "destination": "User Database",
            "data_description": "Query User Data",
            "protocol": "SQL",
        },
    ],
    trust_boundaries=[{"name": "Internet to DMZ"}],
)

# --- Load DFD Components ---
# Load the DFD produced by the previous stage. A missing or empty file falls
# back to SAMPLE_DFD; unreadable or structurally wrong input aborts the run.
import sys  # sys.exit, not the interactive `exit` helper (which IPython replaces with a kernel-exit hook)

logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, "r", encoding="utf-8") as f:
        dfd_data = json.load(f)
except FileNotFoundError:
    logger.warning(f"DFD file not found at '{DFD_INPUT_PATH}'. Using sample DFD for demonstration.")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError as e:
    logger.error(f"FATAL: Error decoding JSON from '{DFD_INPUT_PATH}': {e}")
    sys.exit(1)
except Exception as e:
    logger.error(f"FATAL: Error loading DFD: {e}")
    sys.exit(1)
else:
    if not dfd_data:
        logger.warning(f"DFD file at '{DFD_INPUT_PATH}' is empty. Using sample DFD for demonstration.")
        dfd_data = SAMPLE_DFD
    elif not isinstance(dfd_data, dict):
        # Downstream code iterates dfd_data.items(); a JSON array or scalar
        # would otherwise fail much later with a confusing error.
        logger.error(
            f"FATAL: DFD in '{DFD_INPUT_PATH}' must be a JSON object, "
            f"got {type(dfd_data).__name__}."
        )
        sys.exit(1)

# --- Enhanced Prompting Strategy ---

# STRIDE categories, listed explicitly so that every component is checked
# against all six. Each row is (key letter, category name, short definition);
# the table's order is the iteration order used by the generation loop.
_STRIDE_TABLE = (
    ("S", "Spoofing",
     "Illegitimately accessing systems or data by impersonating a user, process, or component."),
    ("T", "Tampering",
     "Unauthorized modification of data, either in transit or at rest."),
    ("R", "Repudiation",
     "A user or system denying that they performed an action, often due to a lack of sufficient proof (e.g., logs)."),
    ("I", "Information Disclosure",
     "Exposing sensitive information to unauthorized individuals."),
    ("D", "Denial of Service",
     "Preventing legitimate users from accessing a system or service."),
    ("E", "Elevation of Privilege",
     "A user or process gaining rights beyond their authorized level."),
)

# Letter -> (name, definition), e.g. "S" -> ("Spoofing", "...").
STRIDE_DEFINITIONS = {letter: (label, meaning) for letter, label, meaning in _STRIDE_TABLE}

# Prompt template for ONE (component, STRIDE category) pair per LLM call.
# Narrowing each call to a single category is intended to keep the model from
# producing generic, category-agnostic threats.
# Placeholders, filled with str.format in the main loop:
#   component_info     - JSON dump of {"type": <DFD section>, "details": <item>}
#   rag_context        - retrieved knowledge-base chunks joined by "\n---\n"
#   stride_category    - single-letter key from STRIDE_DEFINITIONS (e.g. "S");
#                        it also appears in the example schema, so the model
#                        is asked to echo the letter, not the full name
#   stride_name        - matching category name (e.g. "Spoofing")
#   stride_definition  - matching one-line definition
# The doubled braces ({{ / }}) are str.format escapes that render as literal
# { } around the example JSON object.
threat_prompt_template_specific_rag = """
You are a cybersecurity architect specializing in threat modeling using the STRIDE methodology.
Your task is to generate 1-2 specific threats for a given DFD component, focusing ONLY on a single STRIDE category.

**DFD Component to Analyze:**
{component_info}

**STRIDE Category to Focus On:**
- **{stride_category} ({stride_name}):** {stride_definition}

**Security Context from Knowledge Base (for accuracy):**
'''
{rag_context}
'''

**Instructions:**
1.  Generate 1-2 distinct and realistic threats for the component that fall **strictly** under the '{stride_name}' category.
2.  **Be specific.** Relate the threat directly to the component's type and details. For a database, a Spoofing threat is a spoofed connection, not user impersonation. For a data flow, a Tampering threat is a Man-in-the-Middle attack.
3.  Use the provided Security Context to create specific descriptions, **actionable mitigations**, and accurate references (e.g., CWE, OWASP Cheat Sheets). Do not invent references.
4.  Provide a realistic risk assessment (Impact, Likelihood, Score).
5.  Output ONLY a valid JSON object with a single key "threats", containing a list of threat objects. Do not include any other text or commentary.

**JSON Threat Object Schema:**
{{
  "component_name": "string (the name of the component being analyzed)",
  "stride_category": "{stride_category}",
  "threat_description": "string (Specific to the component and STRIDE category)",
  "mitigation_suggestion": "string (Actionable and specific)",
  "impact": "Low, Medium, or High",
  "likelihood": "Low, Medium, or High",
  "references": ["list of strings, e.g., 'OWASP A01:2021', 'CWE-89'"],
  "risk_score": "Critical, High, Medium, or Low"
}}
"""

# --- Main Invocation Logic ---
# For every DFD component x every STRIDE category, ask the LLM (with RAG
# context) for 1-2 threats, then rank, validate and save all of them.


def _collect_components(dfd):
    """Flatten a DFD dict into one ``{"type": section, "details": item}`` entry per component.

    Only list-valued sections are walked and only dict items are kept; scalar
    sections and non-dict list entries are ignored.
    """
    return [
        {"type": section, "details": item}
        for section, items in dfd.items()
        if isinstance(items, list)
        for item in items
        if isinstance(item, dict)
    ]


def _component_label(component, fallback):
    """Return a readable label for logs and for threats that lack ``component_name``.

    Uses ``details["name"]`` when it is non-empty. Data flows carry
    ``source``/``destination`` instead of a name and become
    ``"source -> destination"``. Anything else gets *fallback*.
    """
    details = component.get("details", {})
    if details.get("name"):
        return details["name"]
    if details.get("source") and details.get("destination"):
        return f"{details['source']} -> {details['destination']}"
    return fallback


def _extract_threats(generated_data, label, stride_letter):
    """Return the dict threats from a parsed LLM reply, or ``None`` if it has no ``threats`` list.

    Non-dict entries are dropped (with a warning) because the fill-in and sort
    steps need mappings. Missing or empty ``component_name`` /
    ``stride_category`` values are filled from *label* / *stride_letter*.
    """
    if not isinstance(generated_data, dict) or not isinstance(generated_data.get("threats"), list):
        return None
    raw_threats = generated_data["threats"]
    threats = [t for t in raw_threats if isinstance(t, dict)]
    dropped = len(raw_threats) - len(threats)
    if dropped:
        logger.warning(f"--- Dropped {dropped} malformed threat entries for {label} ---")
    for threat in threats:
        if not threat.get("component_name"):
            threat["component_name"] = label
        if not threat.get("stride_category"):
            threat["stride_category"] = stride_letter
    return threats


logger.info("\n--- Invoking LLM with RAG to systematically generate STRIDE threats ---")
all_threats = []
try:
    components_to_analyze = _collect_components(dfd_data)

    # Every category is considered for every component.
    for component in components_to_analyze:
        component_str = json.dumps(component)
        component_name = _component_label(component, component_str)
        logger.info(f"\n--- Analyzing component: {component_name} ---")

        # Retrieve once per component; the same context is reused for all six categories.
        retrieved_docs = rag_db.similarity_search(component_str, k=3)
        rag_context = "\n---\n".join(doc.page_content for doc in retrieved_docs)
        logger.info("--- Retrieved RAG context for component ---")

        for cat_letter, (cat_name, cat_def) in STRIDE_DEFINITIONS.items():
            logger.info(f"--- Generating threats for STRIDE category: {cat_name} ---")

            prompt = threat_prompt_template_specific_rag.format(
                component_info=component_str,
                rag_context=rag_context,
                stride_category=cat_letter,
                stride_name=cat_name,
                stride_definition=cat_def,
            )

            # Network/API failures: log and move on to the next category.
            try:
                response = client.chat.completions.create(
                    model=LLM_MODEL,
                    messages=[{"role": "user", "content": prompt}],
                    response_format={"type": "json_object"},
                    max_tokens=2048,
                    temperature=0.4,  # Slightly lower temp for more focused output
                )
            except Exception as e:
                logger.error(f"--- An API error occurred for {cat_name} on {component_name}: {e} ---")
                continue

            # Malformed responses: missing choices/content or non-JSON text.
            response_content = None
            try:
                response_content = response.choices[0].message.content
                generated_data = json.loads(response_content)
            except (json.JSONDecodeError, TypeError, AttributeError, IndexError) as e:
                logger.warning(f"--- Could not parse LLM response for {cat_name} on {component_name}: {e} ---")
                logger.debug(f"Raw Response: {response_content}")
                continue

            threats = _extract_threats(generated_data, component_name, cat_letter)
            if threats is None:
                logger.warning(f"--- LLM response for {cat_name} on {component_name} had unexpected structure. ---")
                logger.debug(f"Raw Response: {response_content}")
                continue

            all_threats.extend(threats)
            logger.info(f"--- Successfully generated {len(threats)} threat(s) for category {cat_name} ---")

    # --- Final Processing and Validation ---
    risk_order = {"Critical": 4, "High": 3, "Medium": 2, "Low": 1, "Informational": 0}
    # Case is normalized only for ranking ("high" sorts with "High"); stored
    # values are left untouched. Missing scores rank as "Low", unknown ones last.
    all_threats.sort(
        key=lambda t: risk_order.get(str(t.get("risk_score", "Low")).strip().capitalize(), 0),
        reverse=True,
    )

    final_output = {
        "threats": all_threats,
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "source_dfd": os.path.basename(DFD_INPUT_PATH),
            "llm_model": LLM_MODEL,
            "rag_index": FAISS_INDEX_PATH,
        },
    }

    # Validation is advisory: the output is written even when it fails, so
    # partial results are not lost.
    try:
        ThreatsOutput(**final_output)
    except ValidationError as ve:
        logger.error(f"--- FINAL JSON VALIDATION FAILED: {ve} ---")
    else:
        logger.info("--- Final JSON output validated successfully against schema ---")

    output_dir = os.path.dirname(THREATS_OUTPUT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(THREATS_OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(final_output, f, indent=2)

    logger.info("\n--- LLM RAG Output (Identified Threats) ---")
    logger.info(f"\n--- Identified {len(all_threats)} threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error("\n--- An error occurred during the threat generation process ---")
    logger.exception(f"Error: {e}")

2025-07-28 13:18:01,977 - INFO - --- Setting up RAG pipeline ---
2025-07-28 13:18:01,978 - INFO - Use pytorch device_name: mps
2025-07-28 13:18:01,978 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
2025-07-28 13:18:04,729 - INFO - --- No existing FAISS index found. Building a new one. ---
0it [00:00, ?it/s]
100%|██████████| 10/10 [00:00<00:00, 4631.52it/s]
0it [00:00, ?it/s]
2025-07-28 13:18:04,737 - INFO - --- Creating FAISS index from 105 document chunks. This may take a moment... ---
  return forward_call(*args, **kwargs)
2025-07-28 13:18:06,856 - INFO - --- FAISS index created and saved to 'faiss_index' ---
2025-07-28 13:18:06,869 - INFO - --- OpenAI client initialized successfully ---
2025-07-28 13:18:06,871 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-28 13:18:06,871 - INFO - 
--- Invoking LLM with RAG to generate STRIDE threats ---
2025-07-28 13:18:06,871 - INFO - 
--- Analyzing component: {"type": "ext

{
  "threats": [
    {
      "component_name": "U",
      "stride_category": "I",
      "threat_description": "Improper Restriction of XML External Entity Reference (CWE-611) could allow an attacker to access sensitive information by manipulating XML entities.",
      "mitigation_suggestion": "Implement proper restriction of XML external entity references and use secure parsing methods.",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "CWE-611"
      ],
      "risk_score": "High"
    },
    {
      "component_name": "DB_P",
      "stride_category": "T",
      "threat_description": "Server-Side Request Forgery (SSRF) attacks can be launched against the DB_P component, allowing attackers to manipulate database queries and extract sensitive data.",
      "mitigation_suggestion": "Implement proper input validation and sanitization to prevent malicious requests from being processed.",
      "impact": "High",
      "likelihood": "Medium",
      "reference