## Import modules

In [5]:
from neo4j import GraphDatabase
from bs4 import BeautifulSoup
import requests
import re
import pandas as pd
from stix2 import Filter
from stix2 import MemoryStore
import requests

## Establish a connection to the Neo4j database

In [189]:
class Neo4jConnection:
    """Thin wrapper around a neo4j driver that runs each query in its own session.

    Errors are reported with print() rather than raised: if the driver cannot be
    created it stays unset, and a failing query returns None.
    """

    def __init__(self, uri, user, pwd):
        self.__uri = uri
        self.__user = user
        self.__pwd = pwd
        self.__driver = None
        try:
            self.__driver = GraphDatabase.driver(self.__uri, auth=(self.__user, self.__pwd))
        except Exception as e:
            print("Failed to create the driver:", e)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def close(self):
        """Close the underlying driver if one was created."""
        if self.__driver is not None:
            self.__driver.close()

    def query(self, query, db=None, parameters=None):
        """Run *query* in a fresh session and return all result records as a list.

        query: Cypher text; use ``$name`` placeholders together with *parameters*
               instead of concatenating values into the string.
        db: database name, or None to use the session default.
        parameters: optional mapping of query parameters passed to ``session.run``.

        Returns the list of records, or None if the query failed (the error is printed).
        Raises AssertionError if the driver was never created.
        """
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if self.__driver is None:
            raise AssertionError("Driver not initialized!")
        session = None
        response = None
        try:
            session = self.__driver.session(database=db) if db is not None else self.__driver.session()
            response = list(session.run(query, parameters))
        except Exception as e:
            print("Query failed:", e)
        finally:
            if session is not None:
                session.close()
        return response

In [190]:
import os

# Connection settings can be supplied through the environment so credentials do not
# have to be edited into the notebook; the defaults are the original local settings.
conn = Neo4jConnection(
    uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    user=os.environ.get("NEO4J_USER", "neo4j"),
    pwd=os.environ.get("NEO4J_PASSWORD", "team"),
)

## Create Attack nodes

In [11]:
# ATT&CK release whose technique spreadsheet is loaded; the version appears twice in
# the URL, so it is kept in one place.
attack_version = '10.1'
attack_path = (
    f'https://attack.mitre.org/docs/enterprise-attack-v{attack_version}/'
    f'enterprise-attack-v{attack_version}-techniques.xlsx'
)
#defend_path = 'https://d3fend.mitre.org/ontologies/d3fend.csv'

In [12]:
# Load the first sheet of the technique workbook located at attack_path.
attack_df = pd.read_excel(attack_path, sheet_name=0)
#defend_df = pd.read_csv(defend_path)

In [13]:
#parent_attacks = attack_df[attack_df["ID"].str.len() < 6]
#child_attacks = attack_df[attack_df["ID"].str.len() > 5]

In [14]:
def _cypher_string(value):
    """Return *value* as a single-quoted Cypher string literal, escaping backslashes and quotes."""
    return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"


# One Attack node per spreadsheet row (column 0 = technique ID, column 1 = name).
# Values are escaped so a quote in a name cannot break or alter the statement.
# MERGE on ID keeps re-runs of this cell from creating duplicate Attack nodes, which
# the later MERGE/MATCH-by-ID cells would otherwise fan out over.
for attack_id, attack_name in attack_df.iloc[:, :2].itertuples(index=False, name=None):
    conn.query(
        f"MERGE (n:Attack {{ID: {_cypher_string(attack_id)}}}) "
        f"SET n.name = {_cypher_string(attack_name)}"
    )

## Scrape relationship data from D3FEND

In [15]:
# First spreadsheet column, i.e. the value the node-creation cell stores as Attack.ID;
# each entry is looked up on D3FEND by the scraping loop.
attack_techniques = attack_df.iloc[:, 0]

In [16]:
# Edge lists filled by the scraping cell, one per edge kind:
#   attack_artifact: [attack_id, relation, artifact_name]
#   defend_artifact: [defend_technique, relation, artifact_name]
#   defend_attack:   [defend_technique, relation, attack_id]
attack_artifact, defend_artifact, defend_attack = [], [], []

In [17]:
# Matches a "[label]" annotation in the graph text; compiled once for the whole scrape.
_BRACKETED = re.compile(r'\[[^]]*\]')


def _dotted_attack_id(raw_id):
    """Insert '.' after the 5th character of IDs longer than 5 chars (T1003001 -> T1003.001)."""
    return raw_id[:5] + '.' + raw_id[5:] if len(raw_id) > 5 else raw_id


def _parse_d3fend_edges(graph_text):
    """Split the hidden graph text of a D3FEND page into [source, relation, target] lists.

    All whitespace and "[...]" labels are removed, '-->' is dropped, '-' becomes '_',
    and each ';'-separated statement containing '-->' is split on '|'. Statements that
    yield fewer than three parts are skipped, since later cells index element [2].
    """
    edges = []
    for statement in re.sub(r'\s+', '', graph_text).split(';'):
        if '-->' not in statement:
            continue
        cleaned = _BRACKETED.sub('', statement).replace('-->', '').replace('-', '_')
        parts = cleaned.split('|')
        if len(parts) < 3:
            continue
        edges.append(parts)
    return edges


# One Session reuses the HTTP connection across the many per-technique requests.
with requests.Session() as http:
    for attack_t in attack_techniques:
        url = f'https://d3fend.mitre.org/offensive-technique/attack/{attack_t}'
        page = http.get(url, timeout=30).content
        soup = BeautifulSoup(page, 'html.parser')

        hidden = soup.find('div', class_='hidden')
        if hidden is None:
            continue  # page has no hidden graph div; nothing to extract

        # Classify each edge by which end is an ATT&CK technique ID ('T1...').
        for edge in _parse_d3fend_edges(hidden.text):
            if edge[0].startswith('T1'):
                edge[0] = _dotted_attack_id(edge[0])
                attack_artifact.append(edge)
            elif edge[2].startswith('T1'):
                edge[2] = _dotted_attack_id(edge[2])
                defend_attack.append(edge)
            else:
                defend_artifact.append(edge)

## Create D3FEND and Artifact nodes and their relationships

Attack -> Artifact

In [21]:
def _cypher_string(value):
    """Return *value* as a single-quoted Cypher string literal, escaping backslashes and quotes."""
    return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"


def _cypher_name(value):
    """Return *value* backtick-quoted for use as a Cypher relationship type."""
    return "`" + str(value).replace("`", "``") + "`"


# Rows are [attack_id, relation, artifact_name] scraped from D3FEND; the scraped text
# is escaped/quoted so it cannot break or alter the statement.
for row in attack_artifact:
    conn.query(
        f"MERGE (at:Attack {{ID: {_cypher_string(row[0])}}}) "
        f"MERGE (ar:Artifact {{name: {_cypher_string(row[2])}}}) "
        f"MERGE (at)-[:{_cypher_name(row[1])}]->(ar)"
    )

Defend -> Artifact

In [22]:
def _cypher_string(value):
    """Return *value* as a single-quoted Cypher string literal, escaping backslashes and quotes."""
    return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"


def _cypher_name(value):
    """Return *value* backtick-quoted for use as a Cypher relationship type."""
    return "`" + str(value).replace("`", "``") + "`"


# Rows are [defend_technique, relation, artifact_name] scraped from D3FEND; the scraped
# text is escaped/quoted so it cannot break or alter the statement.
for row in defend_artifact:
    conn.query(
        f"MERGE (de:Defend {{name: {_cypher_string(row[0])}}}) "
        f"MERGE (ar:Artifact {{name: {_cypher_string(row[2])}}}) "
        f"MERGE (de)-[:{_cypher_name(row[1])}]->(ar)"
    )

Defend -> Attack

In [23]:
def _cypher_string(value):
    """Return *value* as a single-quoted Cypher string literal, escaping backslashes and quotes."""
    return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"


def _cypher_name(value):
    """Return *value* backtick-quoted for use as a Cypher relationship type."""
    return "`" + str(value).replace("`", "``") + "`"


# Rows are [defend_technique, relation, attack_id] scraped from D3FEND; the scraped
# text is escaped/quoted so it cannot break or alter the statement.
for row in defend_attack:
    conn.query(
        f"MERGE (de:Defend {{name: {_cypher_string(row[0])}}}) "
        f"MERGE (at:Attack {{ID: {_cypher_string(row[2])}}}) "
        f"MERGE (de)-[:{_cypher_name(row[1])}]->(at)"
    )

## Add Malware nodes and create `uses` relationships to Attack nodes

Download the ATT&CK data in STIX 2 format and build the malware-to-technique mapping.

In [139]:
def get_data_from_branch(domain, branch="master"):
    """Download the ATT&CK STIX bundle for *domain* and load it into a MemoryStore.

    domain: 'enterprise-attack', 'mobile-attack' or 'ics-attack'.
    branch: branch of mitre-attack/attack-stix-data to read (default 'master').

    Raises requests.HTTPError on an error status, instead of failing later on a
    non-JSON body or a missing "objects" key.
    """
    url = (
        "https://raw.githubusercontent.com/mitre-attack/attack-stix-data/"
        f"{branch}/{domain}/{domain}.json"
    )
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return MemoryStore(stix_data=response.json()["objects"])

def remove_revoked_deprecated(stix_objects):
    """Return a list of the objects that are neither revoked nor deprecated.

    A missing ``x_mitre_deprecated`` / ``revoked`` property counts as False. Any value
    other than the literal ``False`` (identity comparison) excludes the object.
    """
    return [
        obj
        for obj in stix_objects
        if obj.get("x_mitre_deprecated", False) is False
        and obj.get("revoked", False) is False
    ]
def get_related(thesrc, src_type, rel_type, target_type, reverse=False):
    """Map STIX object ids to the objects they are linked to by *rel_type*.

    params:
      thesrc: MemoryStore to query
      src_type: text that must appear in each relationship's source_ref, e.g. "malware"
      rel_type: relationship_type to follow, e.g. "uses"
      target_type: text that must appear in each relationship's target_ref, e.g. "attack-pattern"
      reverse: key the result by target id (listing sources) instead of by source id

    Returns ``{stix_id: [{"object": related_obj, "relationship": rel}, ...]}``.
    Links whose other end is not found among the non-revoked objects of that type
    are left out of the lists.
    """
    # Non-revoked relationships of the requested type, with deprecated ones removed too.
    relationships = remove_revoked_deprecated(
        thesrc.query([
            Filter('type', '=', 'relationship'),
            Filter('relationship_type', '=', rel_type),
            Filter('revoked', '=', False),
        ])
    )

    # key id -> [{"relationship": rel, "id": other_id}, ...] in encounter order.
    links_by_key = {}
    for rel in relationships:
        if src_type not in rel.source_ref or target_type not in rel.target_ref:
            continue
        if reverse:
            key_id, other_id = rel.target_ref, rel.source_ref
        else:
            key_id, other_id = rel.source_ref, rel.target_ref
        links_by_key.setdefault(key_id, []).append({"relationship": rel, "id": other_id})

    # Every non-revoked object of the type found on the "other" end of the links.
    other_type = src_type if reverse else target_type
    candidates = thesrc.query([
        Filter('type', '=', other_type),
        Filter('revoked', '=', False),
    ])
    objects_by_id = {obj.id: obj for obj in candidates}

    # Resolve ids to objects, skipping links to ids missing from the lookup
    # (the revoked-filtered query above); a key may end up with an empty list.
    return {
        key_id: [
            {"object": objects_by_id[link["id"]], "relationship": link["relationship"]}
            for link in links
            if link["id"] in objects_by_id
        ]
        for key_id, links in links_by_key.items()
    }

def techniques_used_by_malware(thesrc):
    """Return a new dict ``{malware_stix_id: [{"object": technique, "relationship": rel}, ...]}``
    covering every attack-pattern that a malware object ``uses`` in *thesrc*."""
    return dict(get_related(thesrc, "malware", "uses", "attack-pattern"))

def get_malware_name(malware_id, thesrc=None):
    """Return the ``name`` of the STIX object with id *malware_id*.

    thesrc: store providing ``.get(stix_id)``; defaults to the module-level ``src``,
            so existing one-argument calls behave exactly as before.
    """
    store = src if thesrc is None else thesrc
    return store.get(malware_id).name

def get_malware_att_id(malware_id, thesrc=None):
    """Return the ATT&CK ID (e.g. "S0001") of the STIX object with id *malware_id*.

    The external reference whose ``source_name`` is "mitre-attack" is preferred. If
    there is none, the first reference's ``external_id`` is used, which is what the
    previous implementation always returned.

    thesrc: store providing ``.get(stix_id)``; defaults to the module-level ``src``.
    """
    store = src if thesrc is None else thesrc
    refs = store.get(malware_id).external_references
    for ref in refs:
        # getattr: optional STIX properties may be absent on a reference.
        if getattr(ref, "source_name", None) == "mitre-attack" and getattr(ref, "external_id", None):
            return ref.external_id
    return refs[0].external_id


# Enterprise ATT&CK bundle; `src` is also the store read by the get_malware_* helpers.
src = get_data_from_branch(domain="enterprise-attack")
malware_relation = techniques_used_by_malware(thesrc=src)

Build the Cypher statements.

In [191]:
def _cypher_string(value):
    """Return *value* as a single-quoted Cypher string literal, escaping backslashes and quotes."""
    return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"


# One statement per (malware, technique) pair. The Malware node is only merged and
# linked when the MATCH finds the Attack node; zero matched rows means nothing is written.
cyphers = []
for malware_id, techniques in malware_relation.items():
    if not techniques:
        continue
    # Resolve the malware's ID and name once, not once per technique.
    mal_att_id = _cypher_string(get_malware_att_id(malware_id))
    mal_name = _cypher_string(get_malware_name(malware_id))
    for technique in techniques:
        tech_id = _cypher_string(technique['object'].external_references[0].external_id)
        cyphers.append(
            f"MATCH (at:Attack {{ID: {tech_id}}}) "
            f"MERGE (mal:Malware {{ID: {mal_att_id}, name: {mal_name}}}) "
            f"MERGE (mal)-[r:uses]->(at)"
        )

Execute the Cypher statements on the database.

In [192]:
# Send the generated statements to Neo4j one at a time.
for malware_cypher in cyphers:
    conn.query(query=malware_cypher)