In [1]:
from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_neo4j.vectorstores.neo4j_vector import Neo4jVector

import os

## Init OPENAI_API_KEY

In [2]:
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    api_key=OPENAI_API_KEY
)

## Init NEO4J Database

In [3]:
neo4j_uri = os.environ.get("NEO4J_URI")
neo4j_username = os.environ.get("NEO4J_USERNAME")
neo4j_password = os.environ.get("NEO4J_PASSWORD_ICS")

print(neo4j_uri, neo4j_username, neo4j_password)

bolt://localhost:7687 neo4j attackmitreics


In [None]:
from langchain_neo4j import Neo4jGraph
graph = Neo4jGraph(
    url=neo4j_uri,
    username=neo4j_username,
    password=neo4j_password
)

In [6]:
schema_raw = graph.get_schema

## Cyper Generation Template

In [7]:
cypher_generation_template = """
You are an expert Neo4j Cypher translator who converts English to Cypher based on the Neo4j Schema provided, following the instructions below:
        1. Generate Cypher query compatible ONLY for Neo4j Version 5
        2. Do not use EXISTS, SIZE, HAVING keywords in the cypher. Use alias when using the WITH keyword
        3. Use only Nodes and relationships mentioned in the schema
        5. Never use relationships that are not mentioned in the given schema
        6. For all node labels and relationship types, add namespace prefix `ns0__` before the actual label or relationship type. E.g., `MATCH (n:ns0__NodeLabel)-[:ns0__RelationshipType]->(m:ns0__NodeLabel)`.
        7. Node properties with `created`, `description`, `identifier`, `modified`, `title` and `version`, add prefix `ns1__` instead. E.g., `MATCH (n:ns0__NodeLabel) RETURN n.ns1__title AS Title`.
        8. Always do a case-insensitive and fuzzy search for any properties related search. Eg: to search for a Tactic, use `toLower(Tactic.ns1__title) contains 'persistence'`.
        9. Always assign a meaningful name to every node and relationship in the MATCH clause
        10. Never return components not explicitly named in the MATCH clause.
        11. In the RETURN clause, include all named components (nodes, relationships, or properties) to ensure consistency and understanding.
        12. Always return all the nodes used in the MATCH clause to provide complete information to the user.
        13. When counting distinct items that come from an `OPTIONAL MATCH`, prefer to `collect()` them first and then use `size()` on the collected list to avoid warnings about null values. For example, instead of `count(DISTINCT optional_item)`, use `WITH main_node, collect(DISTINCT optional_item) AS items` and then in the `RETURN` clause use `size(items) AS itemCount`.
        14. To create unique pairs of nodes for comparison (e.g., for similarity calculations), use the `elementId()` function instead of the deprecated `id()` function. For example: `WHERE elementId(node1) < elementId(node2)`.
        15. use `toLower()` function to ensure case-insensitive comparisons for string properties.

Schema:
{schema}

Note: 
Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything other than
for you to construct a Cypher statement. Do not include any text except
the generated Cypher statement. Make sure the direction of the relationship is
correct in your queries. Make sure you alias both entities and relationships
properly. Do not run any queries that would add to or delete from
the database. Make sure to alias all statements that follow as with
statement

In Cypher, you can alias nodes and relationships, but not entire pattern matches using AS directly after a MATCH clause.If you want to alias entire patterns or results of more complex expressions, that should be done in the RETURN clause, not the MATCH clause.
If you want to include any specific properties from these nodes in your results, you can add them to your RETURN statement.

Examples : 

1. Which techniques are commonly used by at least 3 different threat groups?
MATCH (g:ns0__Group)-[:ns0__usesTechnique]->(t:ns0__Technique)
WITH t, count(g) as groupCount
WHERE groupCount >= 3
MATCH (g:ns0__Group)-[:ns0__usesTechnique]->(t)
RETURN t.ns1__title as CommonTechnique, t.ns1__identifier as TechniqueID, 
       groupCount as NumberOfGroups,
       collect(g.ns1__title) as Groups
ORDER BY groupCount DESC

2. Find tactical areas where we have the most significant defensive gaps by identifying tactics that have many techniques but few mitigations, and rank them by coverage percentage!

MATCH (tactic:ns0__Tactic)<-[:ns0__accomplishesTactic]-(technique:ns0__Technique)
WITH tactic, collect(technique) as techniques, count(technique) as techniqueCount
UNWIND techniques as technique
OPTIONAL MATCH (mitigation:ns0__Mitigation)-[:ns0__preventsTechnique]->(technique)
WITH tactic, techniqueCount, technique, count(mitigation) > 0 as hasMitigation

WITH tactic, techniqueCount, 
     sum(CASE WHEN hasMitigation THEN 1 ELSE 0 END) as mitigatedTechniques,
     collect(CASE WHEN NOT hasMitigation THEN technique.ns1__title ELSE NULL END) as unmitigatedTechniques

WITH tactic, techniqueCount, mitigatedTechniques,
     [x IN unmitigatedTechniques WHERE x IS NOT NULL] as filteredUnmitigatedTechniques,
     (toFloat(mitigatedTechniques) / techniqueCount * 100) as coveragePercentage

RETURN tactic.ns1__title as Tactic,
       tactic.ns1__identifier as TacticID,
       techniqueCount as TotalTechniques,
       mitigatedTechniques as MitigatedTechniques,
       techniqueCount - mitigatedTechniques as UnmitigatedTechniqueCount,
       toInteger(coveragePercentage) as CoveragePercentage,
       CASE 
         WHEN coveragePercentage < 30 THEN "CRITICAL" 
         WHEN coveragePercentage < 60 THEN "HIGH" 
         WHEN coveragePercentage < 80 THEN "MEDIUM"
         ELSE "LOW"
       END as RiskLevel,
       filteredUnmitigatedTechniques as UnmitigatedTechniques
ORDER BY coveragePercentage ASC, techniqueCount DESC


The question is:
{question}

"""

In [8]:
from langchain.prompts import PromptTemplate
# from langchain.prompts.prompt import PromptTemplate
cyper_generation_prompt = PromptTemplate(
    template=cypher_generation_template,
    input_variables=["schema","question"]
)


In [9]:
qa_template = """
You are an assistant that takes the results from a Neo4j Cypher query and forms a human-readable response. The query results section contains the results of a Cypher query that was generated based on a user's natural language question. The provided information is authoritative; you must never question it or use your internal knowledge to alter it. Make the answer sound like a response to the question.
Final answer should be easily readable and structured.
Query Results:
{context}

Question: {question}
If the provided information is empty, respond by stating that you don't know the answer. Empty information is indicated by: []
If the information is not empty, you must provide an answer using the results. If the question involves a time duration, assume the query results are in units of days unless specified otherwise.
Never state that you lack sufficient information if data is present in the query results. Always utilize the data provided.
Helpful Answer:
"""

In [10]:
qa_generation_prompt = PromptTemplate(
    template=qa_template,
    input_variables=["context", "question"]
)

## QA Chain

In [9]:
cypher_chain = GraphCypherQAChain.from_llm(
    top_k=10,
    graph=graph,
    verbose=True,
    validate_cypher=True,
    return_intermediate_steps=True,
    cypher_prompt=cyper_generation_prompt,
    qa_prompt=qa_generation_prompt,
    qa_llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
    cypher_llm=ChatOpenAI(model="gpt-4o-mini", temperature=0),
    allow_dangerous_requests=True,
    use_function_response=True
)

def query_cypher(question: str):
    """
    Function to query the cypher chain with a question.
    """
    response = cypher_chain.invoke(question)
    return response

In [10]:
question1 = "Show me all techniques related to 'SQL Injection' and their descriptions."
response = query_cypher(question1)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (technique:ns0__Technique)
WHERE toLower(technique.ns1__title) contains 'sql injection'
RETURN technique.ns1__title AS TechniqueTitle, technique.ns1__description AS TechniqueDescription[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


In [21]:
question2 = "Find the pair of threat groups that have the most similar attack patterns by calculating the Jaccard similarity of the techniques they use. Show the two groups and their similarity score."
response = query_cypher(question2)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (group1:ns0__Group)-[:ns0__usesTechnique]->(technique:ns0__Technique)<-[:ns0__usesTechnique]-(group2:ns0__Group)
WHERE elementId(group1) < elementId(group2)
WITH group1, group2, collect(technique) AS sharedTechniques
MATCH (group1)-[:ns0__usesTechnique]->(technique1:ns0__Technique)
WITH group1, group2, sharedTechniques, collect(technique1) AS group1Techniques
MATCH (group2)-[:ns0__usesTechnique]->(technique2:ns0__Technique)
WITH group1, group2, sharedTechniques, group1Techniques, collect(technique2) AS group2Techniques

WITH group1, group2, size(sharedTechniques) AS intersectionSize, 
     size(group1Techniques) + size(group2Techniques) - size(sharedTechniques) AS unionSize

WITH group1, group2, 
     CASE WHEN unionSize > 0 THEN toFloat(intersectionSize) / unionSize ELSE 0 END AS jaccardSimilarity

RETURN group1.ns1__title AS Group1, 
       group2.ns1__title AS Group2, 
       jaccardSimi

In [13]:
question3 = "Rank all tactics based on the total number of unique mitigations available for all techniques within that tactic. Show the tactic with the most mitigations first."
response = cypher_chain.invoke(question3)



[1m> Entering new GraphCypherQAChain chain...[0m




Generated Cypher:
[32;1m[1;3mcypher
MATCH (tactic:ns0__Tactic)<-[:ns0__accomplishesTactic]-(technique:ns0__Technique)

OPTIONAL MATCH (mitigation:ns0__Mitigation)<-[:ns0__hasMitigation]-(technique)

WITH tactic, collect(DISTINCT mitigation) AS uniqueMitigations

RETURN tactic.ns1__title AS Tactic, 
       tactic.ns1__identifier AS TacticID, 
       size(uniqueMitigations) AS TotalUniqueMitigations

ORDER BY TotalUniqueMitigations DESC
[0m
Full Context:
[32;1m[1;3m[{'Tactic': 'Initial Access', 'TacticID': 'TA0108', 'TotalUniqueMitigations': 32}, {'Tactic': 'Inhibit Response Function', 'TacticID': 'TA0107', 'TotalUniqueMitigations': 26}, {'Tactic': 'Lateral Movement', 'TacticID': 'TA0109', 'TotalUniqueMitigations': 24}, {'Tactic': 'Persistence', 'TacticID': 'TA0110', 'TotalUniqueMitigations': 22}, {'Tactic': 'Collection', 'TacticID': 'TA0100', 'TotalUniqueMitigations': 22}, {'Tactic': 'Execution', 'TacticID': 'TA0104', 'TotalUniqueMitigations': 18}, {'Tactic': 'Evasion', 'TacticID':

In [14]:
print(response.get("result"))

1. Initial Access: 32 unique mitigations
2. Inhibit Response Function: 26 unique mitigations
3. Lateral Movement: 24 unique mitigations
4. Persistence: 22 unique mitigations
5. Collection: 22 unique mitigations
6. Execution: 18 unique mitigations
7. Evasion: 16 unique mitigations
8. Impair Process Control: 14 unique mitigations
9. Impact: 11 unique mitigations
10. Discovery: 7 unique mitigations
