In [1]:
# Get the environment variables
from dotenv import load_dotenv
import os

load_dotenv()
print(f"OpenAI API key => {len(os.getenv('OPENAI_API_KEY')) * '#'}")
print(f"Langchain API Key => {len(os.getenv('LANGCHAIN_API_KEY')) * '#'}")
print(f"Langchain project name => {os.getenv('LANGCHAIN_PROJECT_NAME')}")
print(f"Langchain endpoint => {os.getenv('LANGCHAIN_ENDPOINT')}")
print(f"Langchain tracking => {os.getenv('LANGCHAIN_TRACING_V2')}")

print(f"OpenSearch HOST => {os.getenv('OPENSEARCH_HOST').split('.')[0]}.XXX.XXX.XXX")
print(f"OpenSearch PORT => {os.getenv('OPENSEARCH_PORT')}")
print(f"OpenSearch index name => {os.getenv('OPENSEARCH_INDEX_NAME')}")
print(f"OpenSearch account ID => {os.getenv('OPENSEARCH_ACCOUNT_ID')}")
print(f"OpenSearch account password => {len(os.getenv('OPENSEARCH_ACCOUNT_PASSWORD')) * '#'}")

OpenAI API key => ########################################################
Langchain API Key => ###################################################
Langchain project name => MDR-THREATGHOUL-CTI-RAG-MITRE
Langchain endpoint => https://api.smith.langchain.com
Langchain tracking => true
OpenSearch HOST => 15.XXX.XXX.XXX
OpenSearch PORT => 9200
OpenSearch index name => mdr_threatghoul_cti_rag_mitre
OpenSearch account ID => admin
OpenSearch account password => ###################


In [2]:
# Start Langsmith tracking
from langsmith import traceable
from langchain.callbacks.tracers import LangChainTracer
from langchain.callbacks.manager import CallbackManager

tracer = LangChainTracer(
    project_name=os.getenv('LANGCHAIN_PROJECT_NAME'),
)
callback_manager = CallbackManager([tracer])

# Wrap OpenAI functions
from langchain_openai import OpenAI
openai_llm = OpenAI(temperature=0, callback_manager=callback_manager)

In [3]:
# Import necessary libraries
from MITREAttackScrapper.cti.groups import MITREAttackCTIGroups
from langchain_core.documents import Document
from typing import List, Dict
from tqdm import tqdm

# Initialize an empty list to store documents
documents: List[Document] = []

# Get the list of MITRE ATT&CK Group IDs
attack_group_id_list: List[str] = [attack_group['id'] for attack_group in MITREAttackCTIGroups.get_list()]

# Iterate over a subset of the attack group IDs to fetch detailed information
for attack_group_id in tqdm(attack_group_id_list):
    attack_group_detail = MITREAttackCTIGroups.get(attack_group_id)
    
    # Extract the details of the threat group
    attack_group_id: str                                    = attack_group_detail['id']
    attack_group_name: str                                  = attack_group_detail['name']
    attack_group_description: str                           = attack_group_detail['description']
    attack_group_contributors: List[str]                    = attack_group_detail.get('contributors', [])
    attack_group_version: str                               = attack_group_detail.get('version', 'N/A')
    attack_group_created: str                               = attack_group_detail.get('created', 'N/A')
    attack_group_last_modified: str                         = attack_group_detail.get('last_modified', 'N/A')
    attack_group_url: str                                   = attack_group_detail['url']
    attack_group_associated_groups: List[Dict[str, str]]    = attack_group_detail.get('associated_group_descriptions', [])
    attack_group_techniques: List[Dict[str, str]]           = attack_group_detail.get('techniques_used', [])
    attack_group_softwares: List[Dict[str, str]]            = attack_group_detail.get('software', [])
    attack_group_references: Dict[int, Dict[str, str]]      = attack_group_detail.get('references', {})

    # Create a description related to the threat group
    description = f"{attack_group_name} (MITRE ATT&CK Group ID: {attack_group_id}) is a threat group that {attack_group_description}.\n"
    
    # Add contributors
    if attack_group_contributors:
        description += f"\nContributors: {', '.join(attack_group_contributors)}\n"
    
    # Add creation and modification dates
    description += f"\nVersion: {attack_group_version}\nCreated: {attack_group_created}\nLast Modified: {attack_group_last_modified}\n"
    
    # Add associated groups
    if attack_group_associated_groups:
        description += "\nAssociated Groups:\n"
        for assoc_group in attack_group_associated_groups:
            description += f"- {assoc_group['name']}: {assoc_group['description']}\n"
    else:
        description += "\nAssociated Groups: None\n"
    
    # Add techniques used
    if attack_group_techniques:
        description += "\nTechniques Used:\n"
        for technique in attack_group_techniques:
            description += (
                f"- {technique['main_technique_name']} (ID: {technique['main_technique_id']}): "
                f"{technique['use']}"
            )
            if 'sub_technique_id' in technique and 'sub_technique_name' in technique:
                description += f" (Sub-technique: {technique['sub_technique_name']} (ID: {technique['sub_technique_id']}))"
            description += "\n"
    else:
        description += "\nTechniques Used: None\n"

    # Add software used
    if attack_group_softwares:
        description += "\nSoftware Used:\n"
        for software in attack_group_softwares:
            description += f"- {software['name']} (ID: {software['id']})"
            if 'techniques' in software:
                description += ", associated techniques: "
                for technique in software['techniques']:
                    description += f"{technique['name']} (ID: {technique['url'].split('/')[-1]}), "
            description += "\n"
    else:
        description += "\nSoftware Used: None\n"

    # Create a document and add it to the list
    document = Document(page_content=description, 
                        metadata={"id": attack_group_id, 
                                  "name": attack_group_name,
                                  "url": attack_group_url,
                                  "reference": attack_group_references})
    documents.append(document)

# The 'documents' list now contains the detailed documents for each MITRE ATT&CK group

100%|██████████| 152/152 [01:18<00:00,  1.93it/s]


In [4]:
print(documents[0])

page_content='admin@338 (MITRE ATT&CK Group ID: G0018) is a threat group that admin@338 is a China-based cyber threat group. It has previously used newsworthy events as lures to deliver malware and has primarily targeted organizations involved in financial, economic, and trade policy, typically using publicly available RATs such as PoisonIvy, as well as some non-public backdoors. [1].

Contributors: Tatsuya Daitoku, Cyber Defense Institute, Inc.

Version: 1.2
Created: 2017-05-31
Last Modified: 2020-03-18

Associated Groups: None

Techniques Used:
- Account Discovery (ID: T1087): admin@338 actors used the following commands following exploitation of a machine with LOWBALL malware to enumerate user accounts: net user >> %temp%\download net user /domain >> %temp%\download [1] (Sub-technique: Local Account (ID: T1087.001))
- Command and Scripting Interpreter (ID: T1059): Following exploitation with LOWBALL malware, admin@338 actors created a file containing a list of commands to be execute

In [5]:
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain_openai import OpenAIEmbeddings
from opensearchpy import OpenSearch
import urllib3

# Shut up the SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Delete the index if it exists
opensearch = OpenSearch(
    hosts=[f"{os.getenv('OPENSEARCH_HOST')}:{os.getenv('OPENSEARCH_PORT')}"],
    http_auth=(os.getenv('OPENSEARCH_ACCOUNT_ID'), os.getenv('OPENSEARCH_ACCOUNT_PASSWORD')),
    use_ssl = True,
    verify_certs = False,
    ignore_ssl_warnings = True
)

if opensearch.indices.exists(index=os.getenv('OPENSEARCH_INDEX_NAME')):
    opensearch.indices.delete(index=os.getenv('OPENSEARCH_INDEX_NAME'))
    print(f"Deleted the index {os.getenv('OPENSEARCH_INDEX_NAME')} because it already exists")

print(f"Creating an index with the name {os.getenv('OPENSEARCH_INDEX_NAME')}")

vectorstore = OpenSearchVectorSearch.from_documents(
    index_name=os.getenv('OPENSEARCH_INDEX_NAME'),
    documents=documents,
    embedding=OpenAIEmbeddings(),
    opensearch_url=f"http://{os.getenv('OPENSEARCH_HOST')}:{os.getenv('OPENSEARCH_PORT')}",
    http_auth=(os.getenv('OPENSEARCH_ACCOUNT_ID'), os.getenv('OPENSEARCH_ACCOUNT_PASSWORD')),
    use_ssl = True,
    verify_certs = False,
    ignore_ssl_warnings = True,
)

print(f"Index created with {len(documents)} documents")



Deleted the index mdr_threatghoul_cti_rag_mitre because it already exists
Creating an index with the name mdr_threatghoul_cti_rag_mitre
Index created with 152 documents


In [16]:
# Extract the ontology(top keywords) from the MITRE ATT&CK group information
from MITREAttackScrapper.cti.groups import MITREAttackCTIGroups
from typing import Any
import json

# Extract the list of MITRE ATT&CK group information and save it to the cti_group.json
attack_group_information: List[Dict[str, Any]] = MITREAttackCTIGroups.get_list()

# Extract keywords(ontologies) from each description
def extract_keywords(description: str) -> List[str]:
    keywords = set()
    words = description.split()
    for word in words:
        if len(word) > 3:
            keywords.add(word.lower())

    common_keywords: List[str] = [
        'group', 'groups', 'name', 'id', 'associated', 'description', 'url', 'threat',
        'since', 'least', 'targeted', 'including', 'based', 'primarily', 'well',
        'has', 'have', 'been', 'that', 'this', 'with', 'and', 'for', 'the', 'are', 'was',
    ]

    for keyword in common_keywords:
        if keyword in keywords:
            keywords.remove(keyword)

    return list(keywords)

attack_group_ontology: Dict[str, List[str]] = {}
for attack_group in attack_group_information:
    attack_group_id: str = attack_group['id']
    attack_group_description: str = attack_group['description']
    attack_group_ontology[attack_group_id] = extract_keywords(attack_group_description)

# Print the first five ontology
from pprint import pprint
pprint(dict(list(attack_group_ontology.items())[:5]))

# A function to use ontology for query expansion
def expand_query(query: str, ontology: Dict[str, List[str]]) -> str:
    expanded_terms: List[str] = []
    for group_id, keywords in ontology.items():
        for keyword in keywords:
            if keyword in query.lower():
                expanded_terms.append(group_id)
                break

    expanded_query: str = query + " " + " ".join(expanded_terms)
    return expanded_query

{'G0018': ['policy,',
           'some',
           'available',
           'backdoors.',
           'deliver',
           'admin@338',
           'used',
           'newsworthy',
           'organizations',
           'previously',
           'trade',
           'malware',
           'lures',
           'events',
           'china-based',
           'group.',
           'economic,',
           'using',
           'non-public',
           'publicly',
           'financial,',
           'typically',
           'rats',
           'cyber',
           'involved',
           'poisonivy,',
           'such'],
 'G0130': ['defacement',
           '2010',
           'transitioned',
           'team',
           'believed',
           'malware-based',
           'ajax',
           'campaigns',
           'anti-censorship',
           'iran.',
           'espionage',
           'base',
           '2014',
           'iranian',
           'users',
           'that',
           'operating',
        

In [23]:
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import json

@traceable(name="retrieve_relevant_documents")
def retrieve_relevant_documents(query: str, top_k: int = 3) -> List[Document]:

    # Step 1: Query Expansion
    expanded_query: str = expand_query(query, attack_group_ontology)

    # Step 2: Retrieve relevant documents with 3 times the top_k
    retrieved_documents: List[Document] = vectorstore.similarity_search(expanded_query, k=top_k*3)

    # Step 3: Embed the expanded query and retrieved documents
    embedding_model: SentenceTransformer = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
    expanded_query_embedding: List[float] = embedding_model.encode([expanded_query])[0]
    document_embeddings: List[List[float]] = embedding_model.encode([doc.page_content for doc in retrieved_documents])

    # Step 4: Calculate the cosine similarity between the expanded query and retrieved documents
    # The cosine similarity is calculated between the expanded query and the retrieved documents
    cosine_similarities: List[float] = cosine_similarity([expanded_query_embedding], document_embeddings)[0]

    # Step 5: Sort the documents based on cosine similarity and return the top_k documents
    sorted_documents: List[Document] = [doc for _, doc in sorted(zip(cosine_similarities, retrieved_documents), reverse=True)]
    return sorted_documents[:top_k]

# Example usage
user_question = "Which groups are based on the North Korean threat actor?"
relevant_docs = retrieve_relevant_documents(user_question)

print(f"Retrieved {len(relevant_docs)} relevant documents:")
for index, document in enumerate(relevant_docs):
    print(f"\nDocument {index + 1}:")
    print(f"Content: {document.page_content}...") 
    print(f"Metadata: {document.metadata}")

Retrieved 3 relevant documents:

Document 1:
Content: Kimsuky (MITRE ATT&CK Group ID: G0094) is a threat group that Kimsuky is a North Korea-based cyber espionage group that has been active since at least 2012. The group initially focused on targeting South Korean government entities, think tanks, and individuals identified as experts in various fields, and expanded its operations to include the United States, Russia, Europe, and the UN. Kimsuky has focused its intelligence collection activities on foreign policy and national security issues related to the Korean peninsula, nuclear policy, and sanctions.[1][2][3][4][5].

Contributors: Taewoo Lee, KISA, Dongwook Kim, KISA

Version: 4.0
Created: 2019-08-26
Last Modified: 2024-04-17

Associated Groups:
- Black Banshee: [3][4]
- Velvet Chollima: [9][10][4]
- Emerald Sleet: [11]
- THALLIUM: [3][4]

Techniques Used:
- Account Manipulation (ID: T1098): Kimsuky has added accounts to specific groups with net localgroup . [12] (Sub-technique: No

In [25]:
# Generation
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser

# TODO: Need to implement the extract_references and format_references functions more accurately
def extract_references(docs: List[Document]) -> Dict[str, str]:
    references = {}
    for doc in docs:
        if 'metadata' in doc:
            for ref_id, ref_info in doc.metadata.get('reference', {}).items():
                references[f"{doc.metadata['id']}:{ref_id}"] = f"{ref_info['reference name']} ({ref_info['url']})"
    return references

def format_references(references: Dict[str, str]) -> str:
    formatted_refs = "\n".join([f"[{key}]: {value}" for key, value in references.items()])
    return formatted_refs


# Improved prompt template
prompt = PromptTemplate.from_template(
    """You are an assistant for question-answering tasks. 
Use the following pieces of retrieved context to answer the question. 
If you don't know the answer, just say that you don't know. 
Answer in English, and include references as indicated in the context.

Question: {question} 
Context: {context} 

Answer:"""
)

llm = ChatOpenAI(model_name="gpt-4o", 
                 temperature=0, 
                 callback_manager=callback_manager)

# Function to generate the answer using the LLM
@traceable(name="rag_qa_chain")
def rag_qa_chain(query: str) -> str:
    # Retrieve relevant documents
    relevant_docs = retrieve_relevant_documents(query)

    # Extract and format references
    references = extract_references(relevant_docs)
    formatted_references = format_references(references)

    # Combine the retrieved documents into a single context string
    context = "\n\n".join([doc.page_content for doc in relevant_docs])
    
    # Create the chain
    chain = (
        {"context": lambda x: context, "question": lambda x: x}
        | prompt
        | llm
        | StrOutputParser()
    )
    
    # Run the chain
    answer = chain.invoke(query)
    
    # Combine answer with references
    final_answer = f"{answer}\n\nReferences:\n{formatted_references}"
    return final_answer

# Example usage
user_question = "What race is APT32 made up of (e.g. Americans)?"
answer = rag_qa_chain(user_question)

print(f"User question: {user_question}")
print(f"Answer: {answer}")

Parent run 2d79a9c3-0390-4d42-a636-91b6567ac2ba not found for run 9cf5ff34-ab12-4599-9a63-d55e31e9ff33. Treating as a root run.


User question: What race is APT32 made up of (e.g. Americans)?
Answer: APT32, also known as OceanLotus, is a suspected Vietnam-based threat group. The group is believed to be composed of individuals from Vietnam and has been active since at least 2014, targeting various private sector industries, foreign governments, dissidents, and journalists, particularly in Southeast Asian countries such as Vietnam, the Philippines, Laos, and Cambodia ([MITRE ATT&CK](https://attack.mitre.org/groups/G0050/)).

References:

