In [24]:
import json
import re
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate

# Chat model shared by the chains defined in this cell (GPT-4, temperature fixed at 0).
llm = ChatOpenAI(
    model="gpt-4",
    temperature=0,
)

# Chain 1: free-text system description -> JSON Data Flow Diagram.
# The system message fixes the required top-level keys; the human message is
# just the caller-supplied description.
_dfd_system_message = SystemMessagePromptTemplate.from_template(
    "You are a system architect. Given the system description, generate a detailed Data Flow Diagram (DFD) in JSON format. "
    "The JSON must have 'processes', 'data_stores', 'data_flows', and 'trust_boundaries' as keys, each with an array of items. "
    "Ensure the JSON is valid and structured."
)
_dfd_human_message = HumanMessagePromptTemplate.from_template("{system_description}")

dfd_json_prompt = ChatPromptTemplate.from_messages(
    [_dfd_system_message, _dfd_human_message]
)
dfd_json_chain = LLMChain(llm=llm, prompt=dfd_json_prompt)

# Chain 2: JSON DFD -> STRIDE threat list with DREAD scores, as structured text.
# The human message carries the DFD JSON string under the `dfd_json` variable.
_threat_system_message = SystemMessagePromptTemplate.from_template(
    "You are a cybersecurity expert performing threat modeling using STRIDE and scoring threats using DREAD. "
    "You will receive a JSON representation of a DFD. Identify threats per STRIDE categories and score each threat from 1 to 10 using DREAD. "
    "Output the results in a structured text format. Each threat must include a severity score or a summary DREAD score."
)
_threat_human_message = HumanMessagePromptTemplate.from_template(
    "Here is the JSON DFD:\n{dfd_json}"
)

threat_model_prompt = ChatPromptTemplate.from_messages(
    [_threat_system_message, _threat_human_message]
)
threat_model_chain = LLMChain(llm=llm, prompt=threat_model_prompt)

# Extract threat names and severity scores
def extract_threat_scores(threat_report_text):
    """Parse threat names and their DREAD scores out of a text threat report.

    A threat header is a line (after stripping whitespace) of the form
    ``Threat: <name>``, optionally preceded by a list number such as ``4.``
    or ``12)``.  The first following line that contains ``DREAD Score: <n>``
    (which also covers ``Summary DREAD Score: <n>``) supplies the score for
    that threat.

    Args:
        threat_report_text: The report as a single string, one item per line.

    Returns:
        dict mapping threat name (str) to its score rounded to an int.
        Threats for which no numeric score line is found are omitted.  If the
        same threat name appears more than once, the last score wins.
    """
    # Any list number is accepted, not just 1-3, so longer reports are not
    # silently truncated.  The name is everything after the FIRST "Threat:",
    # so names that themselves contain "Threat:" are kept intact.
    header_pattern = re.compile(r'^(?:\d+\s*[.)]\s*)?Threat:\s*(.*)$')
    # "Summary DREAD Score:" contains "DREAD Score:", so one pattern covers both.
    score_pattern = re.compile(r'DREAD Score:\s*(\d+(?:\.\d+)?)')

    threat_scores = {}
    current_threat = None
    for raw_line in threat_report_text.splitlines():
        line = raw_line.strip()

        header = header_pattern.match(line)
        if header:
            current_threat = header.group(1).strip()
            continue

        # An empty threat name is treated like "no current threat".
        if not current_threat:
            continue

        found = score_pattern.search(line)
        if found:
            # round() resolves exact .5 ties to the nearest even integer.
            threat_scores[current_threat] = round(float(found.group(1)))
            # Only reset once a number was actually captured: a label line
            # without a number (e.g. a per-category breakdown heading) must
            # not discard the threat before its summary score appears.
            current_threat = None
    return threat_scores

# Emoji severity bar visualization
def print_severity_bars(threat_scores):
    """Print one emoji severity bar per threat to stdout.

    Each row shows the threat name (padded to 50 columns), a 10-cell bar, and
    the score as ``<score>/10``.  The bar colour is chosen from the score:
    >= 8 red, >= 6 orange, >= 4 yellow, otherwise green.

    Args:
        threat_scores: Mapping of threat name (str) to numeric score.  Scores
            are expected on a 0-10 scale; values outside that range (or
            non-integer values) are still printed as given, but the bar is
            clamped to 0-10 filled cells so it always has exactly 10 cells.
    """
    severity_icons = ['🟩', '🟨', '🟧', '🟥']  # low, medium, high, critical
    bar_width = 10

    def get_icon(score):
        if score >= 8:
            return severity_icons[3]
        elif score >= 6:
            return severity_icons[2]
        elif score >= 4:
            return severity_icons[1]
        else:
            return severity_icons[0]

    print("\n🔒 Threat Severity (Visual with Emojis):\n")
    for threat, score in threat_scores.items():
        icon = get_icon(score)
        # str * n needs an int, and a negative or >10 count would produce a
        # bar that is not 10 cells wide, so normalise the fill count first.
        filled = min(max(int(round(score)), 0), bar_width)
        bar = icon * filled + '▫️' * (bar_width - filled)
        print(f"{threat:50} | {bar} | {score}/10")

# Main driver function
def automated_security_analysis_json(system_description: str):
    """Run the full pipeline: description -> JSON DFD -> threat report -> bars.

    Steps:
      1. Ask ``dfd_json_chain`` for a JSON Data Flow Diagram and pretty-print it.
      2. Feed that JSON to ``threat_model_chain`` and print the threat report.
      3. Extract per-threat scores from the report and print severity bars.

    If the model output cannot be parsed as JSON, a warning and the raw output
    are printed and the function returns early without threat modeling.

    Args:
        system_description: Free-text description of the system to analyse.

    Returns:
        None. All results are written to stdout.
    """
    print("📥 Generating JSON Data Flow Diagram...\n")
    raw_output = dfd_json_chain.run(system_description=system_description)

    # Try the output verbatim first.  Chat models frequently wrap otherwise
    # valid JSON in a Markdown code fence (```json ... ```), so the fenced
    # payload, if any, is tried as a fallback instead of giving up.
    candidates = [raw_output]
    fenced = re.search(r"```(?:json)?\s*(.*?)```", raw_output, re.DOTALL | re.IGNORECASE)
    if fenced:
        candidates.append(fenced.group(1).strip())

    for candidate in candidates:
        try:
            dfd_json = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        dfd_json_str = candidate
        break
    else:
        # No candidate parsed: show exactly what the model returned.
        print("⚠️ Warning: Couldn't parse JSON. Raw output below:\n")
        print(raw_output)
        return

    print(json.dumps(dfd_json, indent=2))

    print("\n🛡️ Performing Threat Modeling based on JSON DFD...\n")
    # Pass the string that actually parsed, so the second prompt receives
    # clean JSON rather than fence markup.
    threat_text = threat_model_chain.run(dfd_json=dfd_json_str)
    print(threat_text)

    # Extract scores
    threat_scores = extract_threat_scores(threat_text)

    # Visual representation
    if threat_scores:
        print_severity_bars(threat_scores)
    else:
        print("\n⚠️ Could not extract any threat scores from the analysis output.")

# --- SAMPLE RUN ---
# Written as adjacent literals so the leading newline and the trailing space
# after "check out books." are visible rather than hidden in a triple-quoted block.
system_description = (
    "\n"
    "A library management system that allows users to search for books, reserve books, and check out books. \n"
    "The system also allows librarians to manage books, manage users, and view reports.\n"
)

automated_security_analysis_json(system_description=system_description)


📥 Generating JSON Data Flow Diagram...

{
  "processes": [
    {
      "id": "P1",
      "name": "Search Books",
      "description": "Allows users to search for books in the library"
    },
    {
      "id": "P2",
      "name": "Reserve Books",
      "description": "Allows users to reserve books"
    },
    {
      "id": "P3",
      "name": "Check Out Books",
      "description": "Allows users to check out books"
    },
    {
      "id": "P4",
      "name": "Manage Books",
      "description": "Allows librarians to manage books in the library"
    },
    {
      "id": "P5",
      "name": "Manage Users",
      "description": "Allows librarians to manage users"
    },
    {
      "id": "P6",
      "name": "View Reports",
      "description": "Allows librarians to view reports"
    }
  ],
  "data_stores": [
    {
      "id": "DS1",
      "name": "Book Database",
      "description": "Stores all the information about the books"
    },
    {
      "id": "DS2",
      "name": "User Database"