In [None]:
from huggingface_hub import login
from tavily import TavilyClient
import os
# Credentials are never stored in the notebook: each secret is taken from the
# environment when available and otherwise requested interactively (hidden input).
from getpass import getpass

hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(hf_token)

# Only prompt for the Tavily key when it is missing, so a key that was already
# exported in the shell is not overwritten.
if not os.environ.get("TAVILY_API_KEY"):
    os.environ["TAVILY_API_KEY"] = getpass("Tavily API key: ")

In [None]:
import logging
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.tools import tool


# Settings shared by the vector-store helpers below:
#   model_name      - embedding model identifier handed to HuggingFaceEmbeddings
#   collection_name - Chroma collection to open
#   persistent_dir  - directory passed to Chroma as persist_directory
CONFIG = dict(
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    collection_name="vulnerabilities",
    persistent_dir="./chromadb_part2",
)

def load_vectorstore():
    """Open the Chroma collection described by ``CONFIG``.

    Returns:
        A ``Chroma`` instance bound to ``CONFIG["persistent_dir"]`` and
        ``CONFIG["collection_name"]`` that embeds text with the model named in
        ``CONFIG["model_name"]``.
    """
    # NOTE(review): trust_remote_code=True presumably lets the model repository
    # ship custom code that runs locally - confirm the model source is trusted.
    embedder = HuggingFaceEmbeddings(
        model_name=CONFIG["model_name"],
        model_kwargs={"trust_remote_code": True},
    )

    return Chroma(
        collection_name=CONFIG["collection_name"],
        embedding_function=embedder,
        persist_directory=CONFIG["persistent_dir"],
    )

def query_vectorstore(vectorstore, query: str, k: int):
    """Print the top ``k`` similarity-search hits for ``query``.

    Args:
        vectorstore: Object exposing ``similarity_search(query, k=...)``.
        query: Text to search for.
        k: Number of results requested from the store.

    Returns:
        None. Results are written to stdout; any exception raised while
        searching or printing is caught and printed instead of propagating.
    """
    divider = "-" * 80  # Separator between results
    try:
        hits = vectorstore.similarity_search(query, k=k)
        print("\nQuery Results:")
        rank = 0
        for hit in hits:
            rank += 1
            print("\nResult {}:".format(rank))
            # Only the first 200 characters of each hit are shown.
            print("Content: {}...".format(hit.page_content[:200]))
            print("Metadata: {}".format(hit.metadata))
            print(divider)
    except Exception as e:
        print(e)

In [None]:
# Open the persisted collection once and expose it as a retriever for the graph nodes.
vectorstore = load_vectorstore()
# The number of results must be supplied through `search_kwargs`; a bare `k=5`
# keyword is not one of the retriever's search settings, so the intended top-5
# would not be applied.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

In [4]:
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults

# Web search tool invoked by the `web_search` graph node with {"query": ...}.
# NOTE(review): presumably authenticates with the TAVILY_API_KEY environment
# variable set in the first cell - confirm.
web_search_tool = TavilySearchResults()

In [5]:
def assign_unique_test_case_numbers(test_cases):
    """Renumber test-case names in place as ``"Test Case <n>: <title>"``.

    ``<n>`` is the 1-based position in ``test_cases``. ``<title>`` is whatever
    follows the first ``:`` of the current name (or the whole name when it has
    no colon), with surrounding whitespace removed.

    Args:
        test_cases: Iterable of objects with a writable string ``name`` attribute.
    """
    position = 0
    for case in test_cases:
        position += 1
        _prefix, colon, remainder = case.name.partition(':')
        title = remainder if colon else case.name
        case.name = "Test Case {}: {}".format(position, title.strip())

In [6]:
def verify_test_cases(test_cases):
    """Return the test cases with duplicate ids removed, keeping first occurrences.

    Args:
        test_cases: Iterable of objects exposing a hashable ``id`` attribute.

    Returns:
        A new list in the original order. Every dropped duplicate is reported
        on stdout.
    """
    known_ids = set()
    unique_cases = []
    for case in test_cases:
        if case.id in known_ids:
            print(f"Duplicate test case id found: {case.id}")
            continue
        known_ids.add(case.id)
        unique_cases.append(case)
    return unique_cases

In [7]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from typing import List


# Structured result expected from the attack-tree analysis LLM call.
# This model is handed to PydanticOutputParser below, and that parser's format
# instructions are injected into the analysis prompt. Documentation is kept in
# `#` comments rather than a class docstring so the model definition itself is
# left untouched.
class DocumentRetrievalPrompt(BaseModel):
    # Keywords pulled out of the attack tree.
    keywords: List[str] = Field(description="List of keywords extracted from the attack tree analysis")
    # One dict per vulnerability; downstream nodes read the 'name' and 'description' keys.
    vulnerabilities: List[dict] = Field(description="List of identified vulnerabilities with details")
    # Free-text retrieval query; stored in the graph state as `retrieval_prompt`.
    query: str = Field(description="A comprehensive query for document retrieval")

# Parser that validates the analysis reply against DocumentRetrievalPrompt.
# NOTE: `output_parser` is a module-level name that later cells rebind; the chain
# built at the bottom of this cell keeps a reference to this instance.
output_parser = PydanticOutputParser(pydantic_object=DocumentRetrievalPrompt)

# Ollama chat model used for the analysis chain (temperature=0). `llm` is also
# rebound by later cells.
llm = ChatOllama(model="llama3.1:70b", temperature=0)

# Prompt for the attack-tree analysis. `attack_tree` is supplied per call, while
# `format_instructions` is pre-filled from the parser above.
# NOTE(review): the variable name is misspelled ("vulnearbility"); it is left
# unchanged here because it is a module-level name.
vulnearbility_prompt = PromptTemplate(
    template="""You are a senior security engineer tasked with performing an in-depth analysis of an attack tree to identify all possible vulnerabilities in the system. Your analysis should be thorough and comprehensive, leaving no stone unturned.

Here is the attack tree in JSON format:

{attack_tree}

Please perform the following tasks:
1. Conduct a detailed, step-by-step analysis of the attack tree, considering all possible attack vectors and their implications.
2. Identify and Describe all potential vulnerabilities in the system, including their severity, potential impact, and possible mitigation strategies.
3. Extract relevant keywords from the attack tree that are crucial for understanding the system's security landscape.
4. Create a comprehensive query for document retrieval that covers all aspects of the attack tree and identified vulnerabilities.

Ensure your analysis is exhaustive and doesn't overlook any potential security risks. Consider both obvious and non-obvious attack paths.

{format_instructions}

Provide your in-depth analysis and the document retrieval prompt in the specified JSON format.""",
    input_variables=["attack_tree"],
    partial_variables={"format_instructions": output_parser.get_format_instructions()},
)


# Chain: prompt -> model -> parsed DocumentRetrievalPrompt; invoked by analyze_attack_tree.
attack_tree_analyzer = vulnearbility_prompt | llm | output_parser

In [8]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from typing import List

# One generated security test case as returned by the LLM.
# Documentation is kept in `#` comments rather than a class docstring so the
# model definition handed to PydanticOutputParser is left untouched.
class TestCase(BaseModel):
    # Integer id; verify_test_cases drops later cases that reuse an id.
    id: int = Field(description="Unique identifier for the test case")
    # Display name; rewritten by assign_unique_test_case_numbers after regeneration.
    name: str = Field(description="Name of the test case")
    description: str = Field(description="Detailed description of what the test case is checking")
    vulnerability_addressed: str = Field(description="The specific vulnerability this test case is addressing")
    # The three code sections below are plain strings holding source code.
    setup: str = Field(description="Setup code for the test case")
    test_code: str = Field(description="Actual test code")
    teardown: str = Field(description="Teardown code for the test case")
    expected_result: str = Field(description="Expected result of the test")

# Top-level container parsed from the generator output; the graph nodes read
# its `test_cases` list.
class TestCaseSet(BaseModel):
    test_cases: List[TestCase] = Field(description="List of generated detailed test cases")

# Rebinds the module-level `output_parser` to a TestCaseSet parser. Functions
# defined later that read the global `output_parser` at call time
# (generate_test_cases, regenerate_test_cases) therefore use this schema.
output_parser = PydanticOutputParser(pydantic_object=TestCaseSet)

# Prompt for the initial test-case generation. It expects four variables at
# invoke time: attack_tree, attack_tree_analysis, context_documents and
# format_instructions (none of them is pre-filled here).
test_case_prompt = ChatPromptTemplate.from_template("""
You are an elite security test engineer with extensive experience in creating comprehensive, robust, and detailed Python test suites, specifically targeting vulnerabilities in diverse systems and infrastructures. Your task is to generate an exceptionally thorough set of Python test cases for a system based on a detailed attack tree analysis and additional context documents. The goal is to create a test suite that rigorously validates the security of the system and defends against ALL identified vulnerabilities, no matter how complex or nuanced.

**Attack Tree**:
{attack_tree}

**Analysed Vulnerabilities**:
{attack_tree_analysis}

**Additional Context Documents**:
{context_documents}

### Critical Requirements for the Test Suite:
1. **Vulnerability Focus**: Each test case must be designed to specifically test for vulnerabilities mentioned in the attack tree, focusing on real security issues like improper access control, misconfigurations, insecure data storage, unencrypted data transmission, excessive permissions, and API abuse. Go beyond availability checks to test for security weaknesses.
2. **Alignment with Attack Tree**: Create a separate test case for each attack vector identified in the attack tree analysis. No vulnerability should be left untested. Ensure test cases mirror the attack vectors, addressing both the specific scenarios and potential bypass methods.
3. **Depth in Vulnerability Testing**: For complex vulnerabilities, create multiple test methods to cover various scenarios, edge cases, and potential bypass methods, ensuring thorough coverage. Incorporate multiple stages of attacks (e.g., privilege escalation, lateral movement, denial of service).
4. **Realistic Attack Scenarios**: Use realistic, diverse, and complex data sets in your tests, simulating real-world usage and attack scenarios relevant to the system under test. The test cases should go beyond basic functionality checks to simulate attack patterns like privilege escalation, improper role assignment, and compromised credentials.
5. **Security Assertions**: Implement sophisticated assertions that not only verify functionality but also confirm the absence of security vulnerabilities (e.g., improper access permissions, unencrypted data storage, misconfigured security settings).
6. **System-Specific Vulnerability Testing**: Ensure test cases cover security aspects of the specific components or services of the system, focusing on vulnerability testing rather than just availability. For instance, test improper access control for data storage, unencrypted communication channels, insecure network configurations, etc.
7. **Positive and Negative Tests**: Implement both positive tests (verifying secure configurations and expected behavior) and negative tests (attempting to exploit vulnerabilities) for each vulnerability.
8. **Parameterized Tests**: Where applicable, implement parameterized tests to cover a wide range of inputs efficiently, ensuring scalability in vulnerability testing.
9. **Race Conditions and Timing Attacks**: Include tests for complex vulnerabilities like race conditions, timing attacks, and multi-stage attacks that simulate real-world advanced persistent threat (APT) scenarios.
10. **Error Handling**: Include error-handling tests to ensure that the system behaves securely under various error conditions, such as handling malformed requests or insufficient permissions securely.
11. **Comprehensive Coverage**: Maintain comprehensive coverage of all system components but ensure that the tests are focused on testing security vulnerabilities (e.g., roles with excessive permissions, unencrypted data storage, misconfigured security alerts that fail to detect incidents).
12. **Thorough Documentation**: Add exhaustive comments and docstrings explaining:
    - The purpose of each test method
    - The specific vulnerability or attack vector being tested
    - The expected outcome and why it's secure
    - Any subtle points or non-obvious security implications
13. **Runnable Code**: Provide Python code for each test case that is immediately runnable. Ensure the use of the `unittest` framework and appropriate mocking techniques where applicable, but move beyond basic mocking to simulate real-world attack scenarios.
14. **Coverage Report**: After generating the test cases, provide a coverage report explaining how each identified vulnerability is addressed by the test cases.

### Example Format:
Each test case should include:
- **Setup**: Any necessary setup, such as configuring system components, setting user roles and permissions, or initializing data stores.
- **Test Method**: The test logic that simulates the attack and checks for the presence of the vulnerability. The method should include positive tests (for secure configurations) and negative tests (for attempted exploits).
- **Teardown**: Any necessary cleanup to ensure tests don't affect each other.
- **Expected Result**: The expected behavior when the system is secure and the vulnerability is absent.
                                                    
Each test case should include an 'id' field that is a unique integer.
                                                    
{format_instructions}

Remember to provide actual Python code for each test case, not just descriptions, and ensure ALL vulnerabilities are covered.
""")

# Rebinds the module-level `llm`; this instance is created without format="json",
# unlike the instances created in the grader and regeneration cells.
llm = ChatOllama(
    model="llama3.1:70b",
    temperature=0,
)

# Chain: prompt -> model -> parsed TestCaseSet; invoked by generate_test_cases.
rag_testcase_generator = test_case_prompt | llm | output_parser

In [9]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser

# Rebinds the module-level `llm`, this time with format="json" for the grader.
llm = ChatOllama(model="llama3.1:70b", format="json", temperature=0)


# Prompt asking for a binary relevance verdict about one document with respect
# to one vulnerability. Both variables are supplied per call.
grader_prompt = PromptTemplate(
    template="""You are an expert security analyst grading the relevance of a retrieved document to vulnerabilities identified in an attack tree analysis.

Analyzed Vulnerability:
{vulnerability}

Retrieved Document:
{document}

Grade the document's relevance to the vulnerability based on the following criteria:
1. The document discusses the specific vulnerability or closely related security issues.
2. The document provides relevant information for understanding or mitigating the vulnerability.
3. The document contains example code or test cases that could be adapted to test for this vulnerability.

Provide a binary score as a JSON with a single key 'score':
- Use 'yes' if the document is relevant and meets at least two of the above criteria.
- Use 'no' if the document is not relevant or meets fewer than two criteria.

Return only the JSON object with no preamble or explanation.""",
    input_variables=["vulnerability", "document"],
)

# Chain: prompt -> model -> parsed JSON; callers read the 'score' key of the result.
retrieval_grader = grader_prompt | llm | JsonOutputParser()

In [10]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser

# Unconditionally rebinds the module-level `output_parser` to a TestCaseSet
# parser (the same schema as the one created in the generation cell).
output_parser = PydanticOutputParser(pydantic_object=TestCaseSet)

# Rebinds the module-level `llm` with format="json". Functions that read the
# global `llm` at call time (check_test_cases, regenerate_test_cases) use the
# binding that is current when they run.
llm = ChatOllama(model="llama3.1:70b", format="json", temperature=0)

# Prompt for revising existing test cases. It expects existing_test_cases,
# improvement_suggestions, missing_vulnerabilities and format_instructions.
regenerate_test_case_prompt = ChatPromptTemplate.from_template("""
You are an elite security test engineer with extensive experience in creating comprehensive and robust test suites across various systems and infrastructures. Your critical task is to **modify the existing test cases** based on the improvement suggestions provided, and **add new test cases** for any missing vulnerabilities.

### Existing Test Cases:
{existing_test_cases}

### Improvement Suggestions:
{improvement_suggestions}

### Missing Vulnerabilities:
{missing_vulnerabilities}

### Instructions:
1. **Modify the existing test cases** to incorporate all improvement suggestions. Only make changes where improvements are suggested; retain other content.
2. For each **missing vulnerability**, **create a new test case** that exactly addresses the vulnerability.
3. Ensure that all test cases use appropriate and actual code relevant to the system under test, utilizing standard libraries or APIs suitable for that system.
4. Include all necessary **setup**, including required imports and initialization of system components or services if needed.
5. The test code must be **complete, runnable Python code**. Do not use pseudocode or placeholders.
6. Follow **best practices** for the system or domain you are testing, and use appropriate methods and calls.
7. Each test case should demonstrate both the **vulnerable state and the secure state**.
8. Use **assert statements** to clearly indicate what constitutes a pass or fail condition.
9. Each test case should include an 'id' field that is a unique integer.
{format_instructions}
""")

# Chain: prompt -> model -> parsed TestCaseSet.
# NOTE(review): no code in the visible source invokes this chain;
# regenerate_test_cases builds its own prompt and calls `llm` directly.
regenerate_test_case_generator = regenerate_test_case_prompt | llm | output_parser


In [None]:
from typing_extensions import TypedDict, List
from langgraph.graph import START, END, StateGraph
from langchain.vectorstores import VectorStore
import json
from langchain_core.messages import AIMessage
from IPython.display import Image, display
from langchain.schema import Document
from requests.exceptions import HTTPError
import re


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        attack_tree: JSON representation of the attack tree (a JSON string)
        analysis: Result of the vulnerability analysis (keys: keywords,
            vulnerabilities, query)
        retrieval_prompt: Generated prompt for document retrieval
        documents: List of retrieved documents
        test_cases: Generated test cases. Although annotated as ``List[dict]``,
            the nodes store the ``TestCase`` model instances returned by the
            output parser and call ``.dict()`` on them.
        steps: List of steps taken (node names, appended in execution order)
        document_grades: List of document grades
        search_needed: Flag to determine if web search is needed
        alignment_check: Alignment check results
        regeneration_attempts: Number of regeneration attempts
        vulnerabilities: Declared here, but not written by any node in the
            visible source; the vulnerabilities in use live under
            ``analysis["vulnerabilities"]``.
    """
    attack_tree: str
    analysis: dict
    retrieval_prompt: str
    documents: List[Document]
    test_cases: List[dict]
    steps: List[str]
    document_grades: List[dict]
    search_needed: bool
    alignment_check: dict
    regeneration_attempts: int
    vulnerabilities: dict


def analyze_attack_tree(state):
    """Run the attack-tree analysis chain and record its findings in the state.

    Args:
        state (dict): The current graph state; reads ``attack_tree`` and ``steps``.

    Returns:
        state (dict): The same state object with ``analysis`` (keywords,
        vulnerabilities, query), ``retrieval_prompt`` and ``steps`` updated.
    """
    analysis_result = attack_tree_analyzer.invoke({"attack_tree": state["attack_tree"]})

    history = state["steps"]
    history.append("analyze_attack_tree")

    # Fall back to an empty list when the result carries no vulnerabilities attribute.
    found_vulnerabilities = getattr(analysis_result, 'vulnerabilities', [])

    # Gather every value first, then write them into the state in one go.
    summary = {
        "keywords": analysis_result.keywords,
        "vulnerabilities": found_vulnerabilities,
        "query": analysis_result.query,
    }
    retrieval_query = analysis_result.query

    state["analysis"] = summary
    state["retrieval_prompt"] = retrieval_query
    state["steps"] = history
    return state



def retrieve_documents(state):
    """
    Retrieve documents based on the vulnerabilities.

    One retrieval query is issued per vulnerability, built from its ``name``
    and ``description`` (whichever of the two are present and non-empty).
    Vulnerabilities with neither field are skipped instead of raising, and
    documents whose text was already retrieved for an earlier vulnerability
    are not added a second time.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updated state with retrieved documents
    """
    vulnerabilities = state["analysis"]["vulnerabilities"]
    steps = state["steps"]
    steps.append("retrieve_documents")
    documents = []
    seen_contents = set()  # page_content of every document kept so far

    for vulnerability in vulnerabilities:
        print(vulnerability)

        # "name: description" when both exist, otherwise whichever one is available.
        query_parts = [
            str(vulnerability[key])
            for key in ("name", "description")
            if vulnerability.get(key)
        ]
        if not query_parts:
            print(f"Skipping vulnerability without name or description: {vulnerability}")
            continue
        query = ": ".join(query_parts)

        for doc in retriever.get_relevant_documents(query):
            # The same chunk is often returned for several vulnerabilities; keeping
            # one copy avoids grading and prompting with duplicates downstream.
            if doc.page_content in seen_contents:
                continue
            seen_contents.add(doc.page_content)
            documents.append(doc)

    state["documents"] = documents
    state["steps"] = steps
    return state


def grade_documents(state):
    """Grade every retrieved document against every analysed vulnerability.

    Documents graded ``'yes'`` for at least one vulnerability are kept exactly
    once in ``state["documents"]``. ``search_needed`` is set when any
    vulnerability ends up with fewer than three relevant documents.

    Args:
        state (dict): The current graph state; reads ``analysis``, ``documents``
            and ``steps``.

    Returns:
        state (dict): The same state with ``documents``, ``search_needed``,
        ``steps`` and ``document_grades`` updated.
    """
    vulnerabilities = state["analysis"]["vulnerabilities"]
    steps = state["steps"]
    steps.append("grade_documents")

    # Grade each distinct text only once per vulnerability: identical chunks
    # would otherwise trigger redundant LLM calls and inflate the relevant count.
    documents = []
    unique_contents = set()
    for doc in state["documents"]:
        if doc.page_content not in unique_contents:
            unique_contents.add(doc.page_content)
            documents.append(doc)

    min_relevant_docs = 3  # Adjust this threshold as needed
    filtered_docs = []
    kept_contents = set()  # texts already placed in filtered_docs
    document_grades = []
    search_needed = False

    for vulnerability in vulnerabilities:
        # A missing name must not discard an otherwise valid grade.
        name = vulnerability.get('name', 'Unknown Vulnerability')
        relevant_count = 0
        for doc in documents:
            try:
                score = retrieval_grader.invoke(
                    {"vulnerability": json.dumps(vulnerability), "document": doc.page_content}
                )
                grade = score["score"]
            except Exception as e:
                print(f"Error grading document: {e}")
                continue

            document_grades.append({
                "vulnerability": name,
                "document_content": doc.page_content[:100] + "...",
                "grade": grade
            })

            # Tolerate non-string or padded verdicts coming back from the model.
            if str(grade).strip().lower() != 'yes':
                continue

            relevant_count += 1
            # A document relevant to several vulnerabilities is kept only once.
            if doc.page_content not in kept_contents:
                kept_contents.add(doc.page_content)
                filtered_docs.append(doc)

        if relevant_count < min_relevant_docs:
            search_needed = True

    state["documents"] = filtered_docs
    state["search_needed"] = search_needed
    state["steps"] = steps
    state["document_grades"] = document_grades
    return state



def web_search(state):
    """Add web search results for every vulnerability and record their grades.

    All fetched results are appended to the document list; the grade of each
    result is only logged in ``document_grades`` and is not used for filtering.

    Args:
        state (dict): The current graph state; reads ``analysis``, ``steps`` and,
            when present, ``documents`` and ``document_grades``.

    Returns:
        dict: A copy of the state with ``documents``, ``steps`` and
        ``document_grades`` set.
    """
    analysis = state["analysis"]
    collected_docs = state.get("documents", [])
    history = state["steps"]
    history.append("web_search")
    grade_log = state.get("document_grades", [])

    for vuln in analysis["vulnerabilities"]:
        # Placeholders keep the query well-formed when a field is absent.
        vuln_name = vuln.get('name', 'Unknown Vulnerability')
        vuln_description = vuln.get('description', 'No description provided')

        search_query = f"{vuln_name} - {vuln_description}"
        print(f"Constructed Web Search Query: {search_query}")  # Debugging

        try:
            hits = web_search_tool.invoke({"query": search_query})
            fetched = []
            for hit in hits:
                fetched.append(
                    Document(page_content=hit["content"], metadata={"url": hit["url"]})
                )
            collected_docs.extend(fetched)

            # Grade the freshly fetched documents for the log.
            for fetched_doc in fetched:
                verdict = retrieval_grader.invoke(
                    {"vulnerability": json.dumps(vuln), "document": fetched_doc.page_content}
                )
                grade_log.append({
                    "vulnerability": str(vuln),
                    "document_content": fetched_doc.page_content[:100] + "...",
                    "grade": verdict["score"]
                })

        except Exception as e:
            # A failed search or grading for one vulnerability does not stop the rest.
            print(f"Error during web search or document grading: {e}")

    next_state = dict(state)  # Include all existing state
    next_state["documents"] = collected_docs
    next_state["steps"] = history
    next_state["document_grades"] = grade_log
    return next_state



def decide_to_generate(state):
    """
    Pick the node that follows document grading.

    Args:
        state (dict): The current graph state

    Returns:
        str: "web_search" when ``search_needed`` is truthy, otherwise
        "generate_test_cases"
    """
    return "web_search" if state.get("search_needed", False) else "generate_test_cases"

def generate_test_cases(state):
    """Generate the initial test cases from the analysis and the context documents.

    Args:
        state (dict): The current graph state; reads ``analysis``, ``documents``,
            ``steps`` and ``attack_tree``.

    Returns:
        state (dict): The same state with ``test_cases`` set to the parsed
        generator output.
    """
    # Not used below; the lookup still raises KeyError if the analysis is incomplete.
    vulnerabilities = state["analysis"]["vulnerabilities"]
    context_docs = state["documents"]
    history = state["steps"]
    history.append("generate_test_cases")

    analysis_json = json.dumps(state["analysis"])
    context_text = "\n".join(context_doc.page_content for context_doc in context_docs)
    attack_tree_json = state["attack_tree"]
    format_instructions = output_parser.get_format_instructions()

    generated = rag_testcase_generator.invoke({
        "attack_tree_analysis": analysis_json,
        "context_documents": context_text,
        "attack_tree": attack_tree_json,
        "format_instructions": format_instructions,
    })

    state["test_cases"] = generated.test_cases
    return state

def check_test_cases(state):
    """Ask the LLM to score the generated test cases against the attack tree.

    Args:
        state (dict): The current graph state; reads ``attack_tree`` (a JSON
            string), ``test_cases`` and ``steps``.

    Returns:
        state (dict): The same state with ``alignment_check`` set to the parsed
        JSON reply, or to an all-zero fallback when the reply is not valid JSON.
    """
    # Parsed only so it can be pretty-printed into the prompt below.
    attack_tree = json.loads(state["attack_tree"])
    test_cases = state["test_cases"]
    steps = state["steps"]
    steps.append("check_test_cases")

    alignment_prompt = f"""
    You are an expert security analyst with a critical eye for detail. Your task is to rigorously check if the generated test cases align with the attack tree, are complete, and are of high quality.

    Attack Tree:
    {json.dumps(attack_tree, indent=2)}

    Generated Test Cases:
    {json.dumps([tc.dict() for tc in test_cases], indent=2)}

    Perform a thorough analysis of the test cases and provide the following:
    1. Alignment: Are the test cases properly aligned with the vulnerabilities identified in the attack tree? Be extremely critical.
    2. Completeness: Do the test cases cover ALL vulnerabilities mentioned in the attack tree? List any that are missing or inadequately covered.
    3. Runnability: Is the code in the test cases runnable in Python? Are there any missing imports, setup steps, or other issues that would prevent immediate execution?
    4. Quality: Assess the quality of each test case. Are they thorough? Do they actually test what they claim to test?
    5. Improvements: Suggest specific, detailed improvements for each test case that falls short in any way.

    Provide your analysis in a structured JSON format with the following keys:
    - alignment_score (0-100, be very strict)
    - completeness_score (0-100, be very strict)
    - runnability_score (0-100, be very strict)
    - quality_score (0-100, be very strict)
    - missing_vulnerabilities (list of vulnerabilities not covered or inadequately covered)
    - improvement_suggestions (list of objects, each containing:
        - test_case_name: the name of the test case that needs improvement
        - suggestions: list of specific, detailed suggestions for improving that test case)

    Be extremely critical in your assessment. We need to ensure these test cases are of the highest possible quality.
    """

    # Uses whichever model the module-level `llm` name is bound to at call time.
    alignment_check = llm.invoke(alignment_prompt)

    # Chat models return an AIMessage; anything else is stringified.
    if isinstance(alignment_check, AIMessage):
        alignment_check_content = alignment_check.content
    else:
        alignment_check_content = str(alignment_check)

    try:
        alignment_result = json.loads(alignment_check_content)
    except json.JSONDecodeError:
        # All-zero scores plus non-empty lists make the router choose regeneration.
        print("Warning: Could not parse JSON from LLM response. Using default values.")
        alignment_result = {
            "alignment_score": 0,
            "completeness_score": 0,
            "runnability_score": 0,
            "quality_score": 0,
            "missing_vulnerabilities": ["Could not determine"],
            "improvement_suggestions": [
                {
                    "test_case_name": "Unknown",
                    "suggestions": ["Regenerate all test cases due to parsing error"]
                }
            ]
        }

    state["alignment_check"] = alignment_result
    state["steps"] = steps
    return state

def _merge_regenerated_test_cases(existing_test_cases, regenerated_test_cases):
    """Merge regenerated test cases into the existing ones, matching on name.

    A regenerated case whose name matches an existing case replaces it in place
    and inherits its id. Any other regenerated case is appended; if its id is
    already taken it receives the next free id, so that a later id-based
    de-duplication does not discard it.

    Args:
        existing_test_cases: Current test cases (objects with ``id`` and ``name``).
        regenerated_test_cases: Modified and new test cases returned by the LLM.

    Returns:
        list: The merged test cases, existing order first, new cases appended.
    """
    merged = list(existing_test_cases)
    position_by_name = {}
    for position, tc in enumerate(merged):
        # First occurrence wins so that duplicate names never drop a test case.
        position_by_name.setdefault(tc.name, position)
    used_ids = {tc.id for tc in merged}

    for tc in regenerated_test_cases:
        position = position_by_name.get(tc.name)
        if position is not None:
            tc.id = merged[position].id
            merged[position] = tc
            continue
        if tc.id in used_ids:
            tc.id = max(used_ids) + 1
        used_ids.add(tc.id)
        position_by_name[tc.name] = len(merged)
        merged.append(tc)

    return merged


def regenerate_test_cases(state):
    """Revise the test cases according to the latest alignment check.

    The test cases named in the improvement suggestions are sent back to the
    LLM together with the suggestions and the list of missing vulnerabilities.
    The reply is parsed into test cases and merged into ``state["test_cases"]``.

    Args:
        state (dict): The current graph state; reads ``alignment_check``,
            ``steps``, ``test_cases`` and ``regeneration_attempts``.

    Returns:
        state (dict): The same state with ``regeneration_attempts`` incremented
        and, unless the attempt limit was reached or parsing failed,
        ``test_cases`` updated.
    """
    alignment_check = state["alignment_check"]
    steps = state["steps"]
    steps.append("regenerate_test_cases")
    regeneration_attempts = state.get("regeneration_attempts", 0) + 1
    state["regeneration_attempts"] = regeneration_attempts

    if regeneration_attempts >= 10:
        # Attempt budget exhausted: leave the test cases untouched.
        return state

    missing_vulnerabilities = alignment_check.get("missing_vulnerabilities", []) or []
    improvement_suggestions = alignment_check.get("improvement_suggestions", []) or []

    # Map improvement suggestions to test cases. The list comes from the LLM, so
    # entries that are not objects or lack a test case name are skipped.
    improvements_map = {}
    for improv in improvement_suggestions:
        if isinstance(improv, dict) and "test_case_name" in improv:
            improvements_map[improv["test_case_name"]] = improv.get("suggestions", [])

    # Extract test cases that need to be modified
    test_cases_to_modify = [tc for tc in state["test_cases"] if tc.name in improvements_map]

    # Prepare the prompt inputs
    prompt_inputs = {
        "test_cases_to_modify": json.dumps([tc.dict() for tc in test_cases_to_modify], indent=2),
        "improvements_map": json.dumps(improvements_map, indent=2),
        "missing_vulnerabilities": json.dumps(missing_vulnerabilities, indent=2),
        "format_instructions": output_parser.get_format_instructions(),
    }

    # Regeneration prompt focusing on modifying specific test cases
    regenerate_prompt = f"""
    You are an elite security test engineer with extensive experience in creating comprehensive and robust test suites across various systems and infrastructures. Your critical task is to **modify specific test cases** based on the provided improvement suggestions, and **add new test cases** for any missing vulnerabilities.

    ### Test Cases to Modify:
    {prompt_inputs['test_cases_to_modify']}

    ### Improvement Suggestions:
    {prompt_inputs['improvements_map']}

    ### Missing Vulnerabilities:
    {prompt_inputs['missing_vulnerabilities']}

    ### Instructions:
    1. **Modify the test cases listed above** to incorporate all improvement suggestions specific to each test case. Only make changes where improvements are suggested; retain other content.
    2. For each **missing vulnerability**, **create a new test case** that exactly addresses the vulnerability.
    3. Ensure that all test cases use appropriate and actual code relevant to the system under test, utilizing standard libraries or APIs suitable for that system.
    4. Include all necessary **setup**, including required imports and initialization of system components or services if needed.
    5. The test code must be **complete, runnable Python code**. Do not use pseudocode or placeholders.
    6. Follow **best practices** for the system or domain you are testing, and use appropriate methods and calls.
    7. Each test case should demonstrate both the **vulnerable state and the secure state**.
    8. Use **assert statements** to clearly indicate what constitutes a pass or fail condition.

    {prompt_inputs['format_instructions']}
    """

    # Invoke the LLM directly with the regenerate_prompt
    llm_response = llm.invoke(regenerate_prompt)

    # Extract the content from the AIMessage
    if isinstance(llm_response, AIMessage):
        llm_content = llm_response.content
    else:
        llm_content = str(llm_response)

    # Parse the LLM output using the output parser
    try:
        parsed_output = output_parser.parse(llm_content)
        modified_and_new_test_cases = parsed_output.test_cases
    except Exception as e:
        print(f"Error parsing LLM output: {e}")
        print("LLM Output:", llm_content)
        print("Regeneration failed due to parsing error.")
        # Record all-zero scores; with these values the router cannot consider
        # the test cases satisfactory.
        state["alignment_check"] = {
            "alignment_score": 0,
            "completeness_score": 0,
            "runnability_score": 0,
            "quality_score": 0,
            "missing_vulnerabilities": missing_vulnerabilities,
            "improvement_suggestions": improvement_suggestions
        }
    else:
        # Merge by name - the same key the improvement suggestions use - so that
        # a modified test case replaces its original instead of being appended.
        state["test_cases"] = _merge_regenerated_test_cases(
            state["test_cases"], modified_and_new_test_cases
        )
        assign_unique_test_case_numbers(state["test_cases"])

    state["steps"] = steps
    return state


def decide_to_stop_or_regenerate(state):
    """Route after the alignment check: finish, or regenerate the test cases.

    The workflow ends when nothing is reported missing, no improvements are
    suggested and the alignment, completeness and runnability scores all
    exceed 90 (``quality_score`` is not consulted), or when ten regeneration
    attempts have been made.

    Args:
        state (dict): The current graph state

    Returns:
        ``END`` or the node name "regenerate_test_cases"
    """
    check = state["alignment_check"]
    attempts = state.get("regeneration_attempts", 0)

    nothing_outstanding = not (
        check.get("missing_vulnerabilities") or check.get("improvement_suggestions")
    )
    required_scores = ("alignment_score", "completeness_score", "runnability_score")

    # Scores are only compared once nothing is outstanding.
    if nothing_outstanding and all(check.get(key, 0) > 90 for key in required_scores):
        return END
    if attempts >= 10:
        return END
    return "regenerate_test_cases"

# Define the workflow
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("analyze_attack_tree", analyze_attack_tree)
workflow.add_node("retrieve_documents", retrieve_documents)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("web_search", web_search)
workflow.add_node("generate_test_cases", generate_test_cases)
workflow.add_node("check_test_cases", check_test_cases)
workflow.add_node("regenerate_test_cases", regenerate_test_cases)

# Build the graph: analysis -> retrieval -> grading, then an optional web search
# before generation.
workflow.add_edge(START, "analyze_attack_tree")
workflow.add_edge("analyze_attack_tree", "retrieve_documents")
workflow.add_edge("retrieve_documents", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "web_search": "web_search",
        "generate_test_cases": "generate_test_cases",
    },
)
workflow.add_edge("web_search", "generate_test_cases")
workflow.add_edge("generate_test_cases", "check_test_cases")
# Check/regenerate loop; decide_to_stop_or_regenerate ends it.
workflow.add_conditional_edges(
    "check_test_cases",
    decide_to_stop_or_regenerate,
    {
        END: END,
        "regenerate_test_cases": "regenerate_test_cases"
    }
)
workflow.add_edge("regenerate_test_cases", "check_test_cases")

custom_graph = workflow.compile()

# Drawing the diagram is best-effort: a rendering failure (the PNG is produced
# by an external Mermaid renderer) must not abort the cell, because
# `custom_graph` is already compiled and usable.
try:
    display(Image(custom_graph.get_graph(xray=True).draw_mermaid_png()))
except Exception as e:
    print(f"Could not render the workflow diagram: {e}")


In [None]:
import os
import json
from tqdm import tqdm

# Batch driver: runs the compiled workflow once per attack-tree JSON file and
# writes the resulting test cases and alignment check next to each other.
input_folder = './dataset/output_json/'
output_folder = './LLama_output/'
error_folder = './error_files/'  # Folder to store files that caused errors

# Ensure output and error folders exist
os.makedirs(output_folder, exist_ok=True)
os.makedirs(error_folder, exist_ok=True)

# Get all JSON files in the input folder
attack_tree_files = [f for f in os.listdir(input_folder) if f.endswith('.json')]

# List to keep track of files that had errors
error_files = []

# Process each attack tree file
for filename in tqdm(attack_tree_files, desc="Processing Files"):
    input_file_path = os.path.join(input_folder, filename)
    output_file_name = os.path.splitext(filename)[0] + '_test_cases.json'
    output_file_path = os.path.join(output_folder, output_file_name)

    try:
        # Load the attack tree
        with open(input_file_path, 'r') as f:
            attack_tree = json.load(f)

        # Run the workflow; the tree is re-serialised because the graph state
        # carries it as a JSON string.
        result = custom_graph.invoke(
            {
                "attack_tree": json.dumps(attack_tree),
                "steps": []
            },
            config={"recursion_limit": 50}
        )

        # Get the test cases, dropping any that reuse an id
        test_cases = verify_test_cases(result.get('test_cases', []))

        # Save test cases to the output file
        with open(output_file_path, 'w') as f_out:
            json.dump([tc.dict() for tc in test_cases], f_out, indent=2)

        # Also save the alignment check results
        alignment_check = result.get('alignment_check', {})
        alignment_check_file_name = os.path.splitext(filename)[0] + '_alignment_check.json'
        alignment_check_file_path = os.path.join(output_folder, alignment_check_file_name)
        with open(alignment_check_file_path, 'w') as f_align:
            json.dump(alignment_check, f_align, indent=2)

        # Print status
        print(f"Processed {filename}, saved test cases to {output_file_name}")

    except Exception as e:
        print(f"Error processing file {filename}: {e}")
        error_files.append(filename)
        # Move the problematic file to the error folder for later inspection.
        # NOTE: this removes the file from the input folder, so a re-run of this
        # cell will not pick it up again unless it is moved back.
        os.rename(input_file_path, os.path.join(error_folder, filename))
        continue  # Continue to the next file

# After processing all files, write the list of error files to a log
# (the log is written into the output folder, one file name per line)
error_log_path = os.path.join(output_folder, 'error_files.txt')
with open(error_log_path, 'w') as f_error:
    for error_file in error_files:
        f_error.write(f"{error_file}\n")

print(f"Processing complete. {len(error_files)} files had errors.")