In [17]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install langchain langchain-community langchain-ollama python-dotenv pydantic logging instructor

import os
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import instructor
from ollama import Client  # Added for raw response debugging

# Load environment variables
load_dotenv()

# --- Configuration ---
# Use environment variables for paths and settings
LLM_MODEL = os.getenv("LLM_MODEL", "llama3:8b")  # Changed default to a more reliable model for testing
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists early
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize LLM with Instructor for schema enforcement
llm = instructor.from_provider(f"ollama/{LLM_MODEL}", mode=instructor.Mode.JSON_SCHEMA)

# Added: Raw Ollama client for debugging
ollama_client = Client()

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One STRIDE category: Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege.")
    threat_description: str = Field(description="Clear, specific description of the threat.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low/Medium/High based on potential damage.")
    likelihood: str = Field(description="Low/Medium/High based on exploitability.")
    references: list[str] = Field(description="Array of standard references (e.g., ['OWASP A01:2021', 'NIST SI-2']).")

class Threats(BaseModel):
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "from": "User",
            "to": "Web Application",
            "data": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "from": "Web Application",
            "to": "User Database",
            "data": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r') as f:
        dfd_data = json.load(f)
    if not dfd_data:  # Added: Check for empty data
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    logger.error(f"--- FATAL ERROR: Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error(f"--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    exit(1)

# --- Prompt Engineering for Threat Generation ---
threat_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where applicable.
2. Describe threats considering trust boundaries, protocols, and potential attack vectors (e.g., injection, misconfiguration).
3. Suggest mitigations with references to standards (e.g., "NIST AC-6 for least privilege").
4. Assess impact (Low/Medium/High based on potential damage) and likelihood (Low/Medium/High based on exploitability).

For each threat, include:
- 'component_name': Affected asset, process, data flow, or entity.
- 'stride_category': One STRIDE category.
- 'threat_description': Clear, specific description (e.g., "Attacker intercepts unencrypted data in transit leading to disclosure").
- 'mitigation_suggestion': Practical, actionable mitigation (e.g., "Implement TLS 1.3 with certificate pinning").
- 'impact': Low/Medium/High.
- 'likelihood': Low/Medium/High.
- 'references': Array of strings (e.g., ["OWASP A01:2021", "NIST SI-2"]).

DFD Components:
---
{dfd_json}
---

Generate a JSON object with a key 'threats' (array of threat objects). Output ONLY the JSON, with no additional commentary or formatting.
"""

threat_prompt = ChatPromptTemplate.from_template(threat_prompt_template)

# --- Invocation and Output ---
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Convert the loaded DFD dictionary back to a JSON string for the prompt
    dfd_json_string = json.dumps(dfd_data, indent=2)

    # Generate messages from the prompt template
    messages = threat_prompt.format_messages(dfd_json=dfd_json_string)

    # Added: Log the prompt for debugging
    logger.info(f"--- Prompt sent to LLM ---\n{messages[0].content}")

    # Added: Call raw Ollama for response debugging (before Instructor)
    raw_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": messages[0].content}])
    logger.info(f"--- Raw LLM Response ---\n{raw_response['message']['content']}")

    # Invoke the LLM with Instructor for structured output
    threats_obj = llm.chat.completions.create(
        messages=[{"role": "user", "content": messages[0].content}],
        response_model=Threats,
        max_retries=5  # Increased for better handling
    )

    threats_dict = threats_obj.model_dump()
    
    # Add metadata
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }
    
    # Validate the output against schema (Instructor already enforces, but double-check)
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise
    
    # Save the threats to a new file
    with open(THREATS_OUTPUT_PATH, 'w') as f:
        json.dump(threats_dict, f, indent=2)
        
    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error(f"\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input data.")

2025-07-27 19:49:51,081 - INFO - Initializing ollama provider with model llama3:8b
2025-07-27 19:49:51,103 - INFO - Client initialized
2025-07-27 19:49:51,118 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 19:49:51,119 - INFO - --- DFD components loaded successfully ---
2025-07-27 19:49:51,119 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-27 19:49:51,120 - INFO - --- Prompt sent to LLM ---

You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where 

{
  "threats": [
    {
      "component_name": "U",
      "stride_category": "Information Disclosure",
      "threat_description": "Unencrypted data transmitted from U to CDN potentially disclosed",
      "mitigation_suggestion": "Implement end-to-end encryption; Use HTTPS protocol",
      "impact": "Medium",
      "likelihood": "High",
      "references": [
        "OWASP A01:2021",
        "NIST SI-2"
      ]
    },
    {
      "component_name": "CDN",
      "stride_category": "Tampering",
      "threat_description": "Malicious actor intercepts and modifies data in transit from CDN to LB",
      "mitigation_suggestion": "Implement integrity checking; Use digital signatures",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A03:2021",
        "MITRE CA-8"
      ]
    },
    {
      "component_name": "LB",
      "stride_category": "Elevation of Privilege",
      "threat_description": "Unprivileged actor gains elevated privileges on LB, potentia

In [34]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install instructor openai pydantic logging python-dotenv

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from ollama import Client  # For raw debugging
import instructor
from openai import OpenAI  # Wrapper for Ollama

# Load environment variables
load_dotenv()

# --- Configuration ---
LLM_MODEL = os.getenv("LLM_MODEL", "llama3-70b-m3max:latest")
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize Ollama client with Instructor (using OpenAI wrapper)
try:
    client = instructor.from_openai(
        OpenAI(
            base_url="http://aioverlord:11434/v1",
            api_key="ollama",  # Required but unused
        ),
        mode=instructor.Mode.JSON,  # Use JSON mode for Ollama
    )
    logger.info("--- Ollama client initialized successfully on port 11434 ---")
except Exception as e:
    logger.error(f"--- Failed to initialize Ollama client on port 11434: {e} ---")
    raise

# Raw Ollama client for debugging
try:
    ollama_client = Client(host="http://aioverlord:11434")
    logger.info("--- Raw Ollama client initialized successfully on port 11434 ---")
except Exception as e:
    logger.error(f"--- Failed to initialize raw Ollama client on port 11434: {e} ---")
    raise

# Health check: Ping the server by listing models (or any lightweight endpoint)
try:
    models_response = ollama_client.list()
    logger.info(f"--- Ollama server health check successful on port 11434. Available models: {models_response.get('models', 'None listed')} ---")
except Exception as e:
    logger.error(f"--- Ollama server health check failed on port 11434: {e} ---")
    raise

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One letter: S, T, R, I, D, or E.")  # Aligned with prompt
    threat_description: str = Field(description="Clear, specific description of the threat, including attack vectors and consequences.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low, Medium, or High based on potential damage.")
    likelihood: str = Field(description="Low, Medium, or High based on exploitability.")
    references: list[str] = Field(description="Array of 1-3 valid references (e.g., ['OWASP A01:2021', 'NIST SP 800-53 SC-28', 'CWE-89']).")
    risk_score: str = Field(description="Critical, High, Medium, or Low (calculated from impact and likelihood).")

class Threats(BaseModel):
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "source": "User",
            "destination": "Web Application",
            "data_description": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "source": "Web Application",
            "destination": "User Database",
            "data_description": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r') as f:
        dfd_data = json.load(f)
    if not dfd_data:
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    logger.error(f"--- FATAL ERROR: Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error(f"--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    exit(1)

# --- Improved Prompt Engineering for Threat Generation ---
threat_prompt_template = """
You are a cybersecurity architect specializing in threat modeling using the STRIDE framework. Your task is to generate a complete list of relevant threats based on the provided DFD components in JSON format.

STRIDE categories must be exactly one of: 
- S: Spoofing
- T: Tampering
- R: Repudiation
- I: Information Disclosure
- D: Denial of Service
- E: Elevation of Privilege

Do not use any other categories or variations. Generate threats covering ALL components, including:
- External entities (e.g., "User")
- Processes (e.g., "Web Application")
- Assets (data stores, e.g., "User Database")
- Data flows (label as "Source to Destination", e.g., "User to Web Application")
- Trust boundaries (e.g., "Internet to DMZ")

For each component, explicitly consider all six STRIDE categories (S, T, R, I, D, E) and generate threats for each where realistically applicable, aiming for at least 5-6 threats per component to ensure broad and comprehensive coverage. Justify omissions internally if a category truly does not apply, but prioritize inclusion to avoid gaps.

For each threat, output strictly in this JSON structure:
{{
  "component_name": "Affected component (e.g., 'User to Web Application' or 'User Database')",
  "stride_category": "One letter: S, T, R, I, D, or E",
  "threat_description": "Clear, specific description tied to the component, including attack vectors (e.g., 'via MITM on weak TLS leading to PII exposure and regulatory fines'). Reference protocols from the DFD (e.g., HTTPS, AMQP).",
  "mitigation_suggestion": "Practical, actionable mitigation (e.g., 'Implement HTTPS with certificate pinning and HSTS').",
  "impact": "Low, Medium, or High",
  "likelihood": "Low, Medium, or High",
  "references": ["Array of 1-3 valid strings (e.g., 'OWASP A01:2021', 'NIST SP 800-53 SC-28', 'CWE-89')"],
  "risk_score": "Critical, High, Medium, or Low (calculate as: Critical if Impact=High and Likelihood=Medium/High; High if Impact=High and Likelihood=Low or Impact=Medium and Likelihood=High; Medium if Impact=Medium and Likelihood=Medium/Low or Impact=Low and Likelihood=High; Low otherwise)"
}}

Think step-by-step internally before generating:
1. Parse the DFD JSON: Explicitly list all external_entities, processes, assets (data_stores), data_flows (as 'source to destination'), and trust_boundaries. Ensure every item is addressed.
2. For each listed component, brainstorm applicable STRIDE threats: Consider all six categories, identifying typical risks (e.g., data flows vulnerable to T and I due to transit; processes to D and E from overload or vulns; data stores to I at rest). Generate at least one threat per category unless impossible, to achieve 5-6 per component.
3. Make threats realistic and specific: Include attack vectors (e.g., SQL injection for tampering), consequences (e.g., data breach leading to fines), and reference DFD protocols/architecture.
4. Avoid duplicates: Ensure no identical threats across components; vary descriptions even for similar risks.
5. Assign impact/likelihood with justification: Base on exposure (e.g., public-facing = higher likelihood) and calculate risk_score accurately.
6. Use only valid references from: OWASP Top 10 2021 (A01-A10 only, e.g., A03:2021 for Injection), NIST SP 800-series (e.g., 800-53, 800-63B, 800-52), CWE (e.g., CWE-89), MITRE ATT&CK (e.g., T1071). Limit to 1-3 per threat; do not invent invalid ones like A11.

Examples of good threats:
{{
  "component_name": "User to Web Application",
  "stride_category": "S",
  "threat_description": "Attacker spoofs user identity via phishing to send fake login credentials over HTTP, leading to account takeover and unauthorized access.",
  "mitigation_suggestion": "Implement strong authentication such as OAuth 2.0 with JWT tokens and MFA.",
  "impact": "High",
  "likelihood": "Medium",
  "references": ["OWASP A05:2021", "NIST SP 800-63B", "CWE-287"],
  "risk_score": "Critical"
}},
{{
  "component_name": "User Database",
  "stride_category": "I",
  "threat_description": "Unauthorized access via SQL injection leading to data leakage of sensitive PII, resulting in privacy violations and fines.",
  "mitigation_suggestion": "Encrypt data at rest using AES-256, implement parameterized queries, and enforce least-privilege access controls.",
  "impact": "High",
  "likelihood": "Medium",
  "references": ["OWASP A04:2021", "NIST SP 800-53 SC-28", "CWE-200"],
  "risk_score": "Critical"
}}

Negative examples to avoid:
- Invalid reference like "OWASP A11:2021" (Top 10 only has A01-A10).
- Generic description like "Unauthorized access to database leading to data leakage" (add vectors/consequences).
- Missing STRIDE coverage without internal justification.
- Fewer than 5 threats per component, leading to gaps.

DFD Components JSON:
{dfd_json}

Output ONLY a JSON object with:
- "threats": [array of threat objects, sorted by risk_score descending (Critical first, then High, Medium, Low)]

Do not include metadata or any other keys. Output ONLY the JSON, with no additional text, commentary, reasoning, or formatting.

"""

# --- Validation Prompt Template ---
validation_prompt_template = """
You are a JSON validator for threat modeling outputs. Your task is to validate and correct the following threats JSON to ensure:
- It is valid JSON.
- Each threat matches the required schema: component_name, stride_category (S/T/R/I/D/E), threat_description, mitigation_suggestion, impact (Low/Medium/High), likelihood (Low/Medium/High), references (list of 1-3 strings), risk_score (Critical/High/Medium/Low).
- No duplicates.
- Risk scores are correctly calculated based on impact and likelihood.
- References are valid (from OWASP, NIST, CWE, MITRE).
- Threats are sorted by risk_score descending (Critical > High > Medium > Low).

If invalid or improvable, correct it. If valid, return the original.

Input Threats JSON:
{threats_json}

Output ONLY the corrected JSON object with key "threats": [array of threat objects]. No other text or explanations.
"""

# --- Invocation and Output ---
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Prepare the generation prompt with DFD JSON
    dfd_json_string = json.dumps(dfd_data, indent=2)
    gen_prompt = threat_prompt_template.format(dfd_json=dfd_json_string)

    # Log the prompt for debugging
    logger.info(f"--- Generation Prompt sent to LLM ---\n{gen_prompt}")

    # Call raw Ollama for response debugging
    raw_gen_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": gen_prompt}])
    logger.info(f"--- Raw LLM Generation Response ---\n{raw_gen_response['message']['content']}")

    # Invoke with Instructor for structured output
    threats_obj = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": gen_prompt}],
        response_model=Threats,
        max_retries=5  # Retries for validation failures
    )
    logger.info("--- Structured threat generation successful ---")

    threats_dict = threats_obj.model_dump()

    # Prepare validation prompt
    threats_json_string = json.dumps(threats_dict, indent=2)
    val_prompt = validation_prompt_template.format(threats_json=threats_json_string)

    # Log the validation prompt
    logger.info(f"--- Validation Prompt sent to LLM ---\n{val_prompt}")

    # Call raw Ollama for validation response
    raw_val_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": val_prompt}])
    logger.info(f"--- Raw LLM Validation Response ---\n{raw_val_response['message']['content']}")

    # Parse validated response
    try:
        validated_threats = json.loads(raw_val_response['message']['content'])
        threats_dict = {"threats": validated_threats.get("threats", threats_dict["threats"])}
        logger.info("--- Threats validated and parsed successfully ---")
    except json.JSONDecodeError:
        logger.warning("--- Validation parsing failed; using original threats ---")

    # Add metadata manually
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }

    # Validate against full schema
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully against schema ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        # Fallback: Try manual parsing of raw generation response
        try:
            raw_json = json.loads(raw_gen_response['message']['content'])
            threats_dict = {"threats": raw_json.get("threats", []), "metadata": threats_dict["metadata"]}
            validated = ThreatsOutput(**threats_dict)
            logger.info("--- Fallback manual parsing succeeded ---")
        except Exception as pe:
            logger.error(f"--- Manual parsing failed: {pe} ---")
            raise

    # Save to file
    with open(THREATS_OUTPUT_PATH, 'w') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error(f"\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to Ollama not enforcing JSON strictly or connection issues on port 5080. Check raw response and adjust prompt.")

2025-07-28 11:08:42,199 - INFO - --- Ollama client initialized successfully on port 11434 ---
2025-07-28 11:08:42,214 - INFO - --- Raw Ollama client initialized successfully on port 11434 ---
2025-07-28 11:08:42,228 - INFO - HTTP Request: GET http://aioverlord:11434/api/tags "HTTP/1.1 200 OK"
2025-07-28 11:08:42,228 - INFO - --- Ollama server health check successful on port 11434. Available models: [Model(model='llama3-70b-m3max:latest', modified_at=datetime.datetime(2025, 7, 28, 11, 7, 1, 64630, tzinfo=TzInfo(+02:00)), digest='11a9abedc8a182544453e5f23f6f425ed0d551200cdaba1aef62e626982d5c05', size=39969745456, details=ModelDetails(parent_model='', format='gguf', family='llama', families=['llama'], parameter_size='70.6B', quantization_level='Q4_0')), Model(model='llama3:70b-instruct', modified_at=datetime.datetime(2025, 7, 28, 10, 7, 48, 277784, tzinfo=TzInfo(+02:00)), digest='786f3184aec0e907952488b865362bdaa38180739a9881a8190d85bad8cab893', size=39969745349, details=ModelDetails(pare

In [None]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install openai pydantic logging python-dotenv

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from openai import OpenAI

# Load environment variables
load_dotenv()

# --- Configuration ---
LLM_MODEL = os.getenv("LLM_MODEL", "llama-3.3-70b-instruct")
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize OpenAI client
try:
    client = OpenAI(
        base_url="https://api.scaleway.ai/4a8fd76b-8606-46e6-afe6-617ce8eeb948/v1",
        api_key=os.getenv("SCW_SECRET_KEY")
    )
    logger.info("--- OpenAI client initialized successfully ---")
except Exception as e:
    logger.error(f"--- Failed to initialize OpenAI client: {e} ---")
    raise

# Health check: Ping the server by listing models
try:
    models_response = client.models.list()
    available_models = [model.id for model in models_response.data]
    logger.info(f"--- Server health check successful. Available models: {available_models} ---")
except Exception as e:
    logger.error(f"--- Server health check failed: {e} ---")
    raise

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One letter: S, T, R, I, D, or E.")  # Aligned with prompt
    threat_description: str = Field(description="Clear, specific description of the threat, including attack vectors and consequences.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low, Medium, or High based on potential damage.")
    likelihood: str = Field(description="Low, Medium, or High based on exploitability.")
    references: list[str] = Field(description="Array of 1-3 valid references (e.g., ['OWASP A01:2021', 'NIST SP 800-53 SC-28', 'CWE-89']).")
    risk_score: str = Field(description="Critical, High, Medium, or Low (calculated from impact and likelihood).")

class Threats(BaseModel):
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "source": "User",
            "destination": "Web Application",
            "data_description": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "source": "Web Application",
            "destination": "User Database",
            "data_description": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r') as f:
        dfd_data = json.load(f)
    if not dfd_data:
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    logger.error(f"--- FATAL ERROR: Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error(f"--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    exit(1)

# --- Improved Prompt Engineering for Threat Generation ---
threat_prompt_template = """
STRIDE categories must be exactly one of:

S: Spoofing
T: Tampering
R: Repudiation
I: Information Disclosure
D: Denial of Service
E: Elevation of Privilege
Do not use any other categories or variations. Generate threats covering ALL components, including:

External entities (e.g., "User")
Processes (e.g., "Web Application")
Assets (data stores, e.g., "User Database")
Data flows (label as "Source to Destination", e.g., "User to Web Application")
Trust boundaries (e.g., "Internet to DMZ")
For each component, explicitly consider all six STRIDE categories (S, T, R, I, D, E) and generate at least one threat for each category where realistically applicable. Aim for exactly 5-6 threats per component (one per applicable STRIDE category, omitting only if truly impossible—e.g., Repudiation may not apply to a read-only data flow). Justify any omissions internally with reasoning (e.g., "R does not apply to this trust boundary as no actions are attributable"). This ensures broad coverage without gaps; total threats should be approximately 5-6 times the number of components (e.g., for 20 components, aim for 100-120 threats).

For each threat, output strictly in this JSON structure:
{{
"component_name": "Affected component (e.g., 'User to Web Application' or 'User Database')",
"stride_category": "One letter: S, T, R, I, D, or E",
"threat_description": "Clear, specific description tied to the component, including attack vectors (e.g., 'via MITM on weak TLS leading to PII exposure and regulatory fines'). Reference protocols from the DFD (e.g., HTTPS, AMQP).",
"mitigation_suggestion": "Practical, actionable mitigation (e.g., 'Implement HTTPS with certificate pinning and HSTS').",
"impact": "Low, Medium, or High",
"likelihood": "Low, Medium, or High",
"references": ["Array of 1-3 valid strings (e.g., 'OWASP A01:2021', 'NIST SP 800-53 SC-28', 'CWE-89')"],
"risk_score": "Critical, High, Medium, or Low (calculate as: Critical if Impact=High and Likelihood=Medium/High; High if Impact=High and Likelihood=Low or Impact=Medium and Likelihood=High; Medium if Impact=Medium and Likelihood=Medium/Low or Impact=Low and Likelihood=High; Low otherwise)"
}}

Think step-by-step internally before generating:

Parse the DFD JSON: Explicitly list ALL external_entities, processes, assets (data_stores), data_flows (as 'Source to Destination'), and trust_boundaries. Confirm the full list (e.g., "External: U; Assets: DB_P, DB_B; ...") to ensure nothing is missed.
For each listed component, brainstorm applicable STRIDE threats: Consider all six categories systematically. Identify typical risks (e.g., data flows vulnerable to T and I due to transit; processes to D and E from overload or vulns; data stores to I at rest). Generate at least one threat per category unless justified (e.g., "Omitting D for this internal boundary as it's not exposed to volumetric attacks").
Make threats realistic and specific: Include attack vectors (e.g., SQL injection for tampering), consequences (e.g., data breach leading to fines), and reference DFD protocols/architecture.
Assign impact/likelihood with justification: Base on exposure (e.g., public-facing = higher likelihood; sensitive data = higher impact). For each threat, internally note: "Impact: High because PII involved; Likelihood: Medium due to common MITM tools."
Calculate risk_score accurately: Use a mental table like:
High + High = Critical
High + Medium = Critical
High + Low = High
Medium + High = High
Medium + Medium = Medium
Etc. Verify no mismatches.
Avoid duplicates: Ensure no identical threats across components; vary descriptions even for similar risks (e.g., different vectors for MITM).
Use only valid references from: OWASP Top 10 2021 (A01-A10 only, e.g., A03:2021 for Injection), NIST SP 800-series (e.g., 800-53, 800-63B, 800-52), CWE (e.g., CWE-89), MITRE ATT&CK (e.g., T1071). Limit to 1-3 per threat; cross-check validity.
Verify completeness: Count threats per component (aim 5-6), ensure all components covered, and confirm sorting by risk_score (Critical > High > Medium > Low).
Examples of good threats:
[Same as original]

Negative examples to avoid:

Invalid reference like "OWASP A11:2021" (Top 10 only has A01-A10).
Generic description like "Unauthorized access to database leading to data leakage" (add vectors/consequences).
Missing STRIDE coverage without internal justification.
Fewer than 5 threats per component, leading to gaps.
Risk miscalculation, e.g., Impact=High, Likelihood=Medium listed as "High" instead of "Critical".
DFD Components JSON:
{dfd_json}

Output ONLY a JSON object with:

"threats": [array of threat objects, sorted by risk_score descending (Critical first, then High, Medium, Low)]
Do not include metadata or any other keys. Output ONLY the JSON, with no additional text, commentary, reasoning, or formatting.
"""

# --- Validation Prompt Template ---
validation_prompt_template = """
You are a JSON validator for threat modeling outputs. Your task is to validate and correct the following threats JSON to ensure:
- It is valid JSON.
- Each threat matches the required schema: component_name, stride_category (S/T/R/I/D/E), threat_description, mitigation_suggestion, impact (Low/Medium/High), likelihood (Low/Medium/High), references (list of 1-3 strings), risk_score (Critical/High/Medium/Low).
- No duplicates.
- Risk scores are correctly calculated based on impact and likelihood.
- References are valid (from OWASP, NIST, CWE, MITRE).
- Threats are sorted by risk_score descending (Critical > High > Medium > Low).

If invalid or improvable, correct it. If valid, return the original.

Input Threats JSON:
{threats_json}

Output ONLY the corrected JSON object with key "threats": [array of threat objects]. No other text or explanations.
"""

# --- Invocation and Output ---
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Prepare the generation prompt with DFD JSON
    dfd_json_string = json.dumps(dfd_data, indent=2)
    gen_prompt = threat_prompt_template.format(dfd_json=dfd_json_string)

    # Log the prompt for debugging
    logger.info(f"--- Generation Prompt sent to LLM ---\n{gen_prompt}")

    # Call for response debugging
    raw_gen_response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": "You are a cybersecurity architect specializing in threat modeling using the STRIDE framework. Your task is to generate a complete list of relevant threats based on the provided DFD components in JSON format."},
            {"role": "user", "content": gen_prompt}
        ],
        response_format={"type": "json_object"},
        max_tokens=4000,  # Increased to handle large output
        temperature=0.6,
        top_p=0.9,
        presence_penalty=0
    )
    logger.info(f"--- Raw LLM Generation Response ---\n{raw_gen_response.choices[0].message.content}")

    # Parse to dict
    threats_dict = json.loads(raw_gen_response.choices[0].message.content)

    # Prepare validation prompt
    threats_json_string = json.dumps(threats_dict, indent=2)
    val_prompt = validation_prompt_template.format(threats_json=threats_json_string)

    # Log the validation prompt
    logger.info(f"--- Validation Prompt sent to LLM ---\n{val_prompt}")

    # Call for validation response
    raw_val_response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": val_prompt}
        ],
        response_format={"type": "json_object"},
        max_tokens=2000,  # Increased for validation
        temperature=0.6,
        top_p=0.9,
        presence_penalty=0
    )
    logger.info(f"--- Raw LLM Validation Response ---\n{raw_val_response.choices[0].message.content}")

    # Parse validated response
    try:
        validated_threats = json.loads(raw_val_response.choices[0].message.content)
        threats_dict = {"threats": validated_threats.get("threats", threats_dict["threats"])}
        logger.info("--- Threats validated and parsed successfully ---")
    except json.JSONDecodeError:
        logger.warning("--- Validation parsing failed; using original threats ---")

    # Add metadata manually
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }

    # Validate against full schema
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully against schema ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        # Fallback: Try manual parsing of raw generation response
        try:
            raw_json = json.loads(raw_gen_response.choices[0].message.content)
            threats_dict = {"threats": raw_json.get("threats", []), "metadata": threats_dict["metadata"]}
            validated = ThreatsOutput(**threats_dict)
            logger.info("--- Fallback manual parsing succeeded ---")
        except Exception as pe:
            logger.error(f"--- Manual parsing failed: {pe} ---")
            raise

    # Save to file
    with open(THREATS_OUTPUT_PATH, 'w') as f:
        json.dump(threats_dict, f, indent=2)

    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error(f"\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to the model not enforcing JSON strictly or connection issues. Check raw response and adjust prompt.")

2025-07-28 11:30:16,588 - INFO - --- OpenAI client initialized successfully ---
2025-07-28 11:30:16,823 - INFO - HTTP Request: GET https://api.scaleway.ai/4a8fd76b-8606-46e6-afe6-617ce8eeb948/v1/models "HTTP/1.1 200 OK"
2025-07-28 11:30:16,825 - INFO - --- Server health check successful. Available models: ['deepseek-r1-distill-llama-70b', 'llama-3.3-70b-instruct', 'qwen2.5-coder-32b-instruct', 'mistral-nemo-instruct-2407', 'llama-3.1-8b-instruct', 'pixtral-12b-2409', 'mistral-small-3.1-24b-instruct-2503', 'gemma-3-27b-it', 'bge-multilingual-gemma2', 'devstral-small-2505'] ---
2025-07-28 11:30:16,827 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-28 11:30:16,829 - INFO - --- DFD components loaded successfully ---
2025-07-28 11:30:16,830 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-28 11:30:16,830 - INFO - --- Generation Prompt sent to LLM ---

STRIDE categories must be exactly one of: 
- S: Spoofing
- T: Tampering
- R: Repu

{
  "threats": [
    {
      "component_name": "User to CDN",
      "stride_category": "S",
      "threat_description": "Attacker spoofs user identity via phishing to send fake login credentials over HTTPS, leading to account takeover and unauthorized access.",
      "mitigation_suggestion": "Implement strong authentication such as OAuth 2.0 with JWT tokens and MFA.",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A05:2021",
        "NIST SP 800-63B",
        "CWE-287"
      ],
      "risk_score": "High"
    },
    {
      "component_name": "DB_P",
      "stride_category": "I",
      "threat_description": "Unauthorized access via SQL injection leading to data leakage of sensitive PII, resulting in privacy violations and fines.",
      "mitigation_suggestion": "Encrypt data at rest using AES-256, implement parameterized queries, and enforce least-privilege access controls.",
      "impact": "High",
      "likelihood": "Medium",
      "referenc