In [17]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install langchain langchain-community langchain-ollama python-dotenv pydantic logging instructor

import os
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
import instructor
from ollama import Client  # Added for raw response debugging

# Load environment variables
load_dotenv()

# --- Configuration ---
# Use environment variables for paths and settings
LLM_MODEL = os.getenv("LLM_MODEL", "llama3:8b")  # Changed default to a more reliable model for testing
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists early
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize LLM with Instructor for schema enforcement
llm = instructor.from_provider(f"ollama/{LLM_MODEL}", mode=instructor.Mode.JSON_SCHEMA)

# Added: Raw Ollama client for debugging
ollama_client = Client()

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One STRIDE category: Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege.")
    threat_description: str = Field(description="Clear, specific description of the threat.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low/Medium/High based on potential damage.")
    likelihood: str = Field(description="Low/Medium/High based on exploitability.")
    references: list[str] = Field(description="Array of standard references (e.g., ['OWASP A01:2021', 'NIST SI-2']).")

class Threats(BaseModel):
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "from": "User",
            "to": "Web Application",
            "data": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "from": "Web Application",
            "to": "User Database",
            "data": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r') as f:
        dfd_data = json.load(f)
    if not dfd_data:  # Added: Check for empty data
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    logger.error(f"--- FATAL ERROR: Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error(f"--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    exit(1)

# --- Prompt Engineering for Threat Generation ---
threat_prompt_template = """
You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where applicable.
2. Describe threats considering trust boundaries, protocols, and potential attack vectors (e.g., injection, misconfiguration).
3. Suggest mitigations with references to standards (e.g., "NIST AC-6 for least privilege").
4. Assess impact (Low/Medium/High based on potential damage) and likelihood (Low/Medium/High based on exploitability).

For each threat, include:
- 'component_name': Affected asset, process, data flow, or entity.
- 'stride_category': One STRIDE category.
- 'threat_description': Clear, specific description (e.g., "Attacker intercepts unencrypted data in transit leading to disclosure").
- 'mitigation_suggestion': Practical, actionable mitigation (e.g., "Implement TLS 1.3 with certificate pinning").
- 'impact': Low/Medium/High.
- 'likelihood': Low/Medium/High.
- 'references': Array of strings (e.g., ["OWASP A01:2021", "NIST SI-2"]).

DFD Components:
---
{dfd_json}
---

Generate a JSON object with a key 'threats' (array of threat objects). Output ONLY the JSON, with no additional commentary or formatting.
"""

threat_prompt = ChatPromptTemplate.from_template(threat_prompt_template)

# --- Invocation and Output ---
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Convert the loaded DFD dictionary back to a JSON string for the prompt
    dfd_json_string = json.dumps(dfd_data, indent=2)

    # Generate messages from the prompt template
    messages = threat_prompt.format_messages(dfd_json=dfd_json_string)

    # Added: Log the prompt for debugging
    logger.info(f"--- Prompt sent to LLM ---\n{messages[0].content}")

    # Added: Call raw Ollama for response debugging (before Instructor)
    raw_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": messages[0].content}])
    logger.info(f"--- Raw LLM Response ---\n{raw_response['message']['content']}")

    # Invoke the LLM with Instructor for structured output
    threats_obj = llm.chat.completions.create(
        messages=[{"role": "user", "content": messages[0].content}],
        response_model=Threats,
        max_retries=5  # Increased for better handling
    )

    threats_dict = threats_obj.model_dump()
    
    # Add metadata
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }
    
    # Validate the output against schema (Instructor already enforces, but double-check)
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        raise
    
    # Save the threats to a new file
    with open(THREATS_OUTPUT_PATH, 'w') as f:
        json.dump(threats_dict, f, indent=2)
        
    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error(f"\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to the LLM not returning a well-formed JSON object or an issue with the input data.")

2025-07-27 19:49:51,081 - INFO - Initializing ollama provider with model llama3:8b
2025-07-27 19:49:51,103 - INFO - Client initialized
2025-07-27 19:49:51,118 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 19:49:51,119 - INFO - --- DFD components loaded successfully ---
2025-07-27 19:49:51,119 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-27 19:49:51,120 - INFO - --- Prompt sent to LLM ---

You are a senior cybersecurity analyst specializing in threat modeling using the STRIDE methodology (Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, Elevation of Privilege), aligned with 2025 standards like OWASP Top 10, NIST SP 800-53, and MITRE ATT&CK.

Based on the provided Data Flow Diagram (DFD) components in JSON format, perform a comprehensive threat analysis. Use Chain-of-Thought reasoning:
1. For each external entity, asset, process, and data flow, systematically apply all STRIDE categories where 

{
  "threats": [
    {
      "component_name": "U",
      "stride_category": "Information Disclosure",
      "threat_description": "Unencrypted data transmitted from U to CDN potentially disclosed",
      "mitigation_suggestion": "Implement end-to-end encryption; Use HTTPS protocol",
      "impact": "Medium",
      "likelihood": "High",
      "references": [
        "OWASP A01:2021",
        "NIST SI-2"
      ]
    },
    {
      "component_name": "CDN",
      "stride_category": "Tampering",
      "threat_description": "Malicious actor intercepts and modifies data in transit from CDN to LB",
      "mitigation_suggestion": "Implement integrity checking; Use digital signatures",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A03:2021",
        "MITRE CA-8"
      ]
    },
    {
      "component_name": "LB",
      "stride_category": "Elevation of Privilege",
      "threat_description": "Unprivileged actor gains elevated privileges on LB, potentia

In [None]:
# --- Dependencies ---
# Ensure you have these packages installed. You can install them using pip:
# pip install instructor openai pydantic logging python-dotenv

import os
import json
from datetime import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import logging
from ollama import Client  # For raw debugging
import instructor
from openai import OpenAI  # Wrapper for Ollama

# Load environment variables
load_dotenv()

# --- Configuration ---
LLM_MODEL = os.getenv("LLM_MODEL", "llama3:70b-instruct")
INPUT_DIR = os.getenv("INPUT_DIR", "./output")
DFD_INPUT_PATH = os.getenv("DFD_INPUT_PATH", os.path.join(INPUT_DIR, "dfd_components.json"))
THREATS_OUTPUT_PATH = os.getenv("THREATS_OUTPUT_PATH", os.path.join(INPUT_DIR, "identified_threats.json"))

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(INPUT_DIR, exist_ok=True)

# Initialize Ollama client with Instructor (using OpenAI wrapper for better compatibility)
client = instructor.from_openai(
    OpenAI(
        base_url="http://localhost:11434/v1",  # Assuming Ollama runs locally
        api_key="ollama",  # Required but unused
    ),
    mode=instructor.Mode.JSON,  # Use JSON mode for Ollama; JSON_SCHEMA may not be fully supported
)

# Raw Ollama client for debugging
ollama_client = Client()

# --- Threat Schema for Validation ---
class Threat(BaseModel):
    component_name: str = Field(description="Affected asset, process, data flow, or entity.")
    stride_category: str = Field(description="One letter: S, T, R, I, D, or E.")  # Aligned with prompt
    threat_description: str = Field(description="Clear, specific description of the threat.")
    mitigation_suggestion: str = Field(description="Practical, actionable mitigation.")
    impact: str = Field(description="Low/Medium/High based on potential damage.")
    likelihood: str = Field(description="Low/Medium/High based on exploitability.")
    references: list[str] = Field(description="Array of standard references (e.g., ['OWASP A01:2021', 'NIST SI-2']).")

class Threats(BaseModel):
    threats: list[Threat]

class ThreatsOutput(BaseModel):
    threats: list[Threat]
    metadata: dict

# --- Sample DFD for Testing (if input is empty) ---
SAMPLE_DFD = {
    "external_entities": ["User", "Attacker"],
    "processes": ["Web Application", "Authentication Service"],
    "data_stores": ["User Database"],
    "data_flows": [
        {
            "from": "User",
            "to": "Web Application",
            "data": "Login Credentials",
            "protocol": "HTTP"
        },
        {
            "from": "Web Application",
            "to": "User Database",
            "data": "Query User Data",
            "protocol": "SQL"
        }
    ],
    "trust_boundaries": ["Internet to DMZ", "DMZ to Internal Network"]
}

# --- Load DFD Components ---
logger.info(f"--- Loading DFD components from '{DFD_INPUT_PATH}' ---")
try:
    with open(DFD_INPUT_PATH, 'r') as f:
        dfd_data = json.load(f)
    if not dfd_data:
        logger.warning("--- DFD data is empty. Using sample DFD for testing ---")
        dfd_data = SAMPLE_DFD
    logger.info("--- DFD components loaded successfully ---")
except FileNotFoundError:
    logger.warning(f"--- Input file not found at '{DFD_INPUT_PATH}'. Using sample DFD for testing ---")
    dfd_data = SAMPLE_DFD
except json.JSONDecodeError:
    logger.error(f"--- FATAL ERROR: Could not parse JSON from '{DFD_INPUT_PATH}' ---")
    logger.error("The file may be corrupted or empty. Using sample DFD for testing.")
    dfd_data = SAMPLE_DFD
except Exception as e:
    logger.error(f"--- FATAL ERROR: An unexpected error occurred while loading DFD components ---")
    logger.error(f"Error details: {e}")
    exit(1)

# --- Prompt Engineering for Threat Generation ---
threat_prompt_template = """
You are a cybersecurity architect specializing in threat modeling using the STRIDE framework. Your task is to generate a list of relevant threats based on the provided DFD components in JSON format.

STRIDE categories must be exactly one of: 
- S: Spoofing
- T: Tampering
- R: Repudiation
- I: Information Disclosure
- D: Denial of Service
- E: Elevation of Privilege

Do not use any other categories or variations. Focus on:
- External entities (e.g., "User")
- Processes (e.g., "Web Application")
- Data stores (e.g., "User Database")
- Data flows (label as "Source to Destination", e.g., "User to Web Application")
- Trust boundaries (e.g., "Internet to DMZ")

For each threat, output strictly in this JSON structure:
{{
  "component_name": "Affected component (e.g., 'User to Web Application' or 'User Database')",
  "stride_category": "One letter: S, T, R, I, D, or E",
  "threat_description": "Clear, specific description tied to the component (e.g., 'Attacker intercepts unencrypted data in transit leading to disclosure'). Reference protocols from the DFD (e.g., HTTP, SQL).",
  "mitigation_suggestion": "Practical, actionable mitigation (e.g., 'Implement HTTPS with certificate pinning').",
  "impact": "Low, Medium, or High",
  "likelihood": "Low, Medium, or High",
  "references": ["Array of 1-3 strings (e.g., 'OWASP A01:2021', 'NIST SP 800-53')"]
}}

Think step-by-step internally before generating:
1. Parse the DFD JSON: List all external_entities, processes, data_stores, data_flows, trust_boundaries.
2. For each component, identify applicable STRIDE threats based on typical risks (e.g., data flows vulnerable to I and T; processes to D and E).
3. Ensure threats are realistic and specific to the architecture (e.g., consider protocols like HTTP, SQL).
4. Avoid duplicates; vary across categories.
5. Use accurate references from OWASP Top 10, NIST, MITRE, etc.

Examples of good threats:
{{
  "component_name": "User to Web Application",
  "stride_category": "S",
  "threat_description": "Attacker spoofs user identity to send fake login credentials over HTTP.",
  "mitigation_suggestion": "Implement strong authentication such as OAuth 2.0 with JWT tokens.",
  "impact": "High",
  "likelihood": "Medium",
  "references": ["OWASP A05:2021", "NIST SP 800-63B"]
}},
{{
  "component_name": "User Database",
  "stride_category": "I",
  "threat_description": "Unauthorized access to database leading to data leakage via SQL queries.",
  "mitigation_suggestion": "Encrypt data at rest using AES-256 and implement access controls.",
  "impact": "High",
  "likelihood": "Medium",
  "references": ["OWASP A04:2021", "NIST SP 800-53 SC-28"]
}}

DFD Components JSON:
{dfd_json}

Output ONLY a JSON object with:
- "threats": [array of threat objects]

Do not include metadata or any other keys. Output ONLY the JSON, with no additional text, commentary, reasoning, or formatting.
"""

# --- Invocation and Output ---
logger.info("\n--- Invoking Local LLM to generate STRIDE threats ---")
try:
    # Prepare the prompt with DFD JSON
    dfd_json_string = json.dumps(dfd_data, indent=2)
    prompt = threat_prompt_template.format(dfd_json=dfd_json_string)

    # Log the prompt for debugging
    logger.info(f"--- Prompt sent to LLM ---\n{prompt}")

    # Call raw Ollama for response debugging
    raw_response = ollama_client.chat(model=LLM_MODEL, messages=[{"role": "user", "content": prompt}])
    logger.info(f"--- Raw LLM Response ---\n{raw_response['message']['content']}")

    # Invoke with Instructor for structured output
    threats_obj = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_model=Threats,
        max_retries=5  # Retries for validation failures
    )

    threats_dict = threats_obj.model_dump()
    
    # Add metadata manually
    threats_dict["metadata"] = {
        "timestamp": datetime.now().isoformat(),
        "source_dfd": DFD_INPUT_PATH
    }
    
    # Validate against full schema
    try:
        validated = ThreatsOutput(**threats_dict)
        logger.info("--- JSON output validated successfully ---")
    except ValidationError as ve:
        logger.error(f"--- JSON validation failed: {ve} ---")
        # Fallback: Try manual parsing of raw response
        try:
            raw_json = json.loads(raw_response['message']['content'])
            threats_dict = {"threats": raw_json.get("threats", []), "metadata": threats_dict["metadata"]}
            validated = ThreatsOutput(**threats_dict)
            logger.info("--- Fallback manual parsing succeeded ---")
        except Exception as pe:
            logger.error(f"--- Manual parsing failed: {pe} ---")
            raise
    
    # Save to file
    with open(THREATS_OUTPUT_PATH, 'w') as f:
        json.dump(threats_dict, f, indent=2)
        
    logger.info("\n--- LLM Output (Identified Threats) ---")
    print(json.dumps(threats_dict, indent=2))
    logger.info(f"\n--- Identified threats successfully saved to '{THREATS_OUTPUT_PATH}' ---")

except Exception as e:
    logger.error(f"\n--- An error occurred during threat generation ---")
    logger.error(f"Error: {e}")
    logger.error("This could be due to Ollama not enforcing JSON strictly. Check raw response and adjust prompt.")

2025-07-27 20:03:14,308 - INFO - --- Loading DFD components from './output/dfd_components.json' ---
2025-07-27 20:03:14,309 - INFO - --- DFD components loaded successfully ---
2025-07-27 20:03:14,309 - INFO - 
--- Invoking Local LLM to generate STRIDE threats ---
2025-07-27 20:03:14,310 - INFO - --- Prompt sent to LLM ---

You are a cybersecurity architect specializing in threat modeling using the STRIDE framework. Your task is to generate a list of relevant threats based on the provided DFD components in JSON format.

STRIDE categories must be exactly one of: 
- S: Spoofing
- T: Tampering
- R: Repudiation
- I: Information Disclosure
- D: Denial of Service
- E: Elevation of Privilege

Do not use any other categories or variations. Generate 10-15 threats to cover a range of components, ensuring balance across STRIDE categories. Focus on:
- External entities (e.g., "User")
- Processes (e.g., "Web Application")
- Data stores (e.g., "User Database")
- Data flows (label as "Source to Destin

{
  "threats": [
    {
      "component_name": "U to CDN",
      "stride_category": "S",
      "threat_description": "Attacker spoofs user identity to send fake login credentials over HTTPS.",
      "mitigation_suggestion": "Implement strong authentication such as OAuth 2.0 with JWT tokens.",
      "impact": "High",
      "likelihood": "Medium",
      "references": [
        "OWASP A05:2021",
        "NIST SP 800-63B"
      ]
    },
    {
      "component_name": "CDN to LB",
      "stride_category": "T",
      "threat_description": "Attacker tamper with data in transit via HTTPS protocol.",
      "mitigation_suggestion": "Implement HTTP header validation and input validation on the receiving end.",
      "impact": "Medium",
      "likelihood": "Low",
      "references": [
        "OWASP A04:2021",
        "NIST SP 800-53 SC-28"
      ]
    },
    {
      "component_name": "LB to WS",
      "stride_category": "D",
      "threat_description": "Attacker floods the server with traffic, cau