In [30]:
import numpy as np
import gzip
import json
import pandas as pd
import os
import sys
import urllib  
import re  
import nltk
import gensim

from sklearn.manifold import TSNE, MDS, Isomap
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt

from pyxll import xl_func

from gensim.models import Doc2Vec
import gensim.models.doc2vec
from collections import OrderedDict
import multiprocessing

from stix2 import FileSystemSource
from stix2 import Filter

from gensim.models.doc2vec import TaggedDocument
from collections import namedtuple

In [31]:
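# Text-cleaning helpers, chained together by clean_text below.
# replace_dots: turns '.'s inside tokens (e.g. file extensions such as 'cmd.exe') into '_'
# so they are not interpreted as full-stops.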
def remove_citations(text):
    """Delete ATT&CK inline citations such as '(Citation: FireEye APT28 2017)'.

    The source name may contain spaces; everything up to the closing ')' is
    removed. (The previous pattern only allowed a single non-space word, so
    multi-word citations were left in the text.)
    """
    return re.sub(r'\(Citations?: [^)]+\)', '', text)

def remove_urls(text):
    """Delete http(s) URLs, including Markdown links of the form '[label](https://...)'.

    Two alternatives are tried, left to right, at each position:
      * a Markdown link whose label may contain spaces;
      * any run of non-space characters containing 'http://' or 'https://'.
    The prefix before the scheme is optional; the old pattern required at least
    one character there, so a bare URL preceded by whitespace was never removed.
    """
    return re.sub(r'\[[^\]\n]*\]\(https?://[^)\s]*\)|\S*https?://\S+', '', text)

def replace_dots(text):
    """Replace every '.' that is immediately followed by a character other than a
    space, '"' or "'" with '_'.

    This keeps dotted tokens such as file names together ('cmd.exe' -> 'cmd_exe'),
    while full stops (followed by a space or a quote) and a dot at the very end of
    the string are left untouched. Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text
    # The lookahead inspects the next character without consuming it, so runs such
    # as '..' are handled and a trailing dot (no next character) never matches.
    # A single regex pass replaces the former bare-except loop that rebuilt the
    # string on every replacement.
    return re.sub(r'\.(?=[^ "\'])', '_', text)

def remove_patterns(text):
    """Drop runs of non-space characters that contain a ':' with at least one
    non-space character on each side (e.g. 'cmd:exe'), optionally wrapped in
    '[...]' / '(...)'.
    """
    colon_token = re.compile(r'\[?\S+\]?\(?[A-Za-z]?:\S+\)?')
    return colon_token.sub('', text)

def remove_snips(text):
    """Delete spans fenced by runs of two or more '-' or '*' characters,
    e.g. '--- snip ---'.

    The old pattern was written in JavaScript literal syntax ('/.../s'). Python
    treats the slashes and the trailing 's' as literal characters, so it never
    matched real text. DOTALL is now passed explicitly so a fenced span may cross
    newlines. The body is matched non-greedily so each fenced span is removed on
    its own instead of everything between the first and the last fence.
    Note: Markdown bold ('**word**') is also fenced by '**' and is removed too.
    """
    return re.sub(r'[\-*]{2,}(.*?)[\-*]{2,}', '', text, flags=re.DOTALL)

def remove_ipaddresses(text):
    """Delete dotted-quad, IPv4-looking tokens (1-3 digits per part).

    Octet ranges are not validated, so '999.1.1.1' is removed as well.
    """
    ipv4_like = r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b'
    return re.sub(ipv4_like, '', text)

def remove_chars(text):
    """Final normalisation step.

    Drops the ATT&CK deprecation boiler-plate, lower-cases the text, strips
    HTML-like tags and surrounding whitespace, and then replaces every character
    other than ASCII letters, apostrophes and underscores with a space.
    """
    to_remove = "This technique has been deprecated. Please see ATT&CK's Initial Access and Execution tactics for replacement techniques."
    text = text.replace(to_remove, '')
    # Lower-case once; the second .lower() in the old code was a no-op.
    text = re.sub(r'<[^>]*>', '', text.lower()).strip()
    # Raw string: in a normal literal '\_' is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python 3.12+). The character class is unchanged.
    return re.sub(r"[^a-zA-Z'_]", ' ', text)

#clean up the text
def clean_text(text):
    """Run the full cleaning pipeline on one description and return the result.

    Order matters: IP addresses are removed *before* replace_dots. replace_dots
    rewrites '10.0.0.1' to '10_0_0_1', which remove_ipaddresses can no longer
    match, so with the old order that step never removed anything.
    """
    clean = remove_citations(text)
    clean = remove_urls(clean)
    clean = remove_ipaddresses(clean)
    clean = replace_dots(clean)
    clean = remove_patterns(clean)
    clean = remove_snips(clean)
    clean = remove_chars(clean)
    return clean

In [32]:
##Extract MITRE Data
# Load every STIX 'relationship' object from the local enterprise and PRE-ATT&CK collections.
# Raw strings: in a normal literal '\p' is an invalid escape sequence (SyntaxWarning on 3.12+).
ENTERPRISE_ATTACK_DIR = r'C:\pyxll\enterprise-attack'
PRE_ATTACK_DIR = r'C:\pyxll\pre-attack'

filt1 = Filter('type', '=', 'relationship')

fs = FileSystemSource(ENTERPRISE_ATTACK_DIR)
techniques_ent1 = fs.query([filt1])

# NOTE: after this cell `fs` refers to the PRE-ATT&CK source.
fs = FileSystemSource(PRE_ATTACK_DIR)
techniques_pre1 = fs.query([filt1])

In [33]:
##Clean MITRE Data
# Keep only 'uses' relationships that carry a description, tag each row with the
# collection it came from, and derive the text columns the model is trained on.

def _described_uses(relationships):
    """Map relationship id -> (relationship_type, description, target_ref).

    Only 'uses' relationships with a present, non-empty description are kept.
    This replaces the old 'Nan' placeholder-and-filter round trip.
    """
    return {obj['id']: (obj['relationship_type'], obj['description'], obj['target_ref'])
            for obj in relationships
            if obj['relationship_type'] == 'uses' and obj.get('description')}


def _relationship_frame(rel_dict, source_type):
    """One row per relationship: its id, the raw (type, description, target) tuple, and `source_type`."""
    return pd.DataFrame({'relationship_id': list(rel_dict.keys()),
                         'values': list(rel_dict.values()),
                         'type': [source_type] * len(rel_dict)})


ent_dict1 = _described_uses(techniques_ent1)
pre_dict1 = _described_uses(techniques_pre1)

ent_df1 = _relationship_frame(ent_dict1, 'ent_attack')
pre_df1 = _relationship_frame(pre_dict1, 'pre_attack')

techniques_df1 = pd.concat([ent_df1, pre_df1], ignore_index=True)
techniques_df1['relationship'] = techniques_df1['values'].map(lambda v: v[0].strip())
techniques_df1['relation_description'] = techniques_df1['values'].map(lambda v: v[1].strip())
techniques_df1['attack_id'] = techniques_df1['values'].map(lambda v: v[2].strip())
techniques_df1['cleanText'] = techniques_df1['relation_description'].map(
    lambda desc: clean_text(str(desc)))

In [34]:
# Lookup table: relationship id -> raw relationship description (a later duplicate id overwrites an earlier one).
all_zip1 = list(zip(techniques_df1['relationship_id'], techniques_df1['relation_description']))
final_dict1 = dict(all_zip1)

In [35]:
def preprocess(text):
    """Tokenise `text` for Doc2Vec.

    Uses gensim.utils.simple_preprocess, then drops gensim's stop words and any
    token of length <= 2. Returns the remaining tokens as a list, in order.
    The former `nltk.bigrams(token)` call was removed: it built a lazy generator
    over the token's characters and discarded it, so it had no effect.
    """
    stopwords = gensim.parsing.preprocessing.STOPWORDS
    return [token for token in gensim.utils.simple_preprocess(text)
            if token not in stopwords and len(token) > 2]

In [36]:
##Create named tuples for gensim
# Doc2Vec consumes objects exposing `.words` and `.tags`; `object_type` rides along as metadata.
Document = namedtuple('TaggedDocument', 'words tags object_type')

alldocs1 = [
    Document(preprocess(clean), [rel_id], 'Relationship')
    for rel_id, clean in zip(techniques_df1['relationship_id'], techniques_df1['cleanText'])
]

In [37]:
import random

cores = multiprocessing.cpu_count()
SHUFFLE_SEED = 42  # fixed so the training order is reproducible across runs

# The message must belong to the same statement. The old trailing comma at the
# end of the line was a SyntaxError.
assert gensim.models.doc2vec.FAST_VERSION > -1, "This will be painfully slow otherwise"

# dm=0 selects PV-DBOW; negative sampling (negative=5) is used instead of hierarchical softmax (hs=0).
model = Doc2Vec(dm=0, vector_size=100, negative=5, hs=0,
                min_count=10, sample=0, epochs=100, workers=cores)

model.build_vocab(alldocs1)

# Train on a shuffled copy so alldocs1 keeps its original order. The seeded RNG makes
# the order reproducible; multi-worker training itself may still vary between runs.
doc_list1 = alldocs1[:]
random.Random(SHUFFLE_SEED).shuffle(doc_list1)

model.train(doc_list1, total_examples=len(doc_list1), epochs=model.epochs)

In [38]:
@xl_func
def find_most_similar_relation(finding):
    """Excel function: id of the training relationship whose document vector is
    closest to the free-text `finding`.

    `finding` is passed through the same clean_text + preprocess pipeline that
    produced the training documents. infer_vector expects a list of word tokens;
    a raw string would be iterated character by character, with each character
    treated as a "word".
    """
    tokens = preprocess(clean_text(str(finding)))
    # gensim 4 renamed Doc2Vec.docvecs to Doc2Vec.dv; support both.
    doc_vectors = model.dv if hasattr(model, 'dv') else model.docvecs
    result = doc_vectors.most_similar(positive=[model.infer_vector(tokens)], topn=1)
    return result[0][0]

In [39]:
# Sources used by the attack-pattern lookups below (fs1: enterprise, fs2: PRE-ATT&CK).
# Each collection is queried through its own source. Previously both queries went
# through `fs`, which the extract cell leaves pointing at PRE-ATT&CK, so
# techniques_ent1 was silently overwritten with PRE-ATT&CK relationships.
fs1 = FileSystemSource(r'C:\pyxll\enterprise-attack')
filt1 = Filter('type', '=', 'relationship')
techniques_ent1 = fs1.query([filt1])

fs2 = FileSystemSource(r'C:\pyxll\pre-attack')
techniques_pre1 = fs2.query([filt1])

def get_technique_by_name(src, name):
    """Return the attack-pattern objects in `src` whose STIX ``id`` equals `name`.

    Despite the function name, `name` is compared against the ``id`` property
    (the relationship target_ref values stored in techniques_df1['attack_id']),
    not against the technique's display name. Returns whatever ``src.query``
    returns for the two filters, which may be empty.
    """
    filt = [
        Filter('type', '=', 'attack-pattern'),
        Filter('id', '=', name),
    ]
    return src.query(filt)


def _closest_relationship_row(finding):
    """Return the techniques_df1 row of the relationship most similar to `finding`.

    Raises LookupError if the id returned by the model is not in techniques_df1.
    (The former per-function loops raised UnboundLocalError in that case.)
    """
    relation = find_most_similar_relation(finding)
    matches = techniques_df1[techniques_df1['relationship_id'].astype(str) == relation]
    if matches.empty:
        raise LookupError('relationship %r not found in techniques_df1' % (relation,))
    # Ids come from dict keys, so there is at most one row; take the last to mirror the old loop.
    return matches.iloc[-1]


def _attack_objects_for(finding):
    """Look up the attack-pattern targeted by the closest relationship.

    Searches the enterprise source (fs1) first and falls back to PRE-ATT&CK (fs2).
    Returns the query result, which may be empty if neither source has the id.
    """
    attack_id = _closest_relationship_row(finding)['attack_id']
    attack_object = get_technique_by_name(fs1, attack_id)
    if not attack_object:
        attack_object = get_technique_by_name(fs2, attack_id)
    return attack_object


def _first_attack_object(finding):
    """Return the first attack-pattern for `finding`; raise LookupError if there is none."""
    attack_object = _attack_objects_for(finding)
    if not attack_object:
        raise LookupError('no attack-pattern found for the relationship closest to %r' % (finding,))
    return attack_object[0]


@xl_func
def get_relation_desc(finding):
    """Excel function: description of the relationship most similar to `finding`."""
    return str(_closest_relationship_row(finding)['relation_description'])


@xl_func
def get_attack(finding):
    """Excel function: attack-pattern query result (possibly empty) for the closest relationship."""
    return _attack_objects_for(finding)


@xl_func
def get_attack_id(finding):
    """Excel function: external id of the attack-pattern behind the closest relationship.

    Uses the first external reference that has an ``external_id``. Previously it
    always used the first reference and raised KeyError when that entry had none.
    """
    external_refs = _first_attack_object(finding).get('external_references') or []
    for ref in external_refs:
        if 'external_id' in ref:
            return ref['external_id']
    raise LookupError('attack-pattern has no external reference with an external_id')


@xl_func
def get_attack_name(finding):
    """Excel function: ``name`` of the attack-pattern behind the closest relationship."""
    return _first_attack_object(finding).get('name')


@xl_func
def get_phase(finding):
    """Excel function: phase_name of the first kill-chain phase of that attack-pattern."""
    return _first_attack_object(finding)['kill_chain_phases'][0]['phase_name']