# Introduction
- A provenance graph is a representation of kernel audit logs that capture all system events in the operating system.
- We used provenance graphs constructed from ATLAS dataset to run our experiment.
- The logs are collected from four hosts, we used three hosts for training and one host for testing.
- The target is to train a graph similarity model, then use it to find similarities between subgraphs extracted from the provenance graphs and attack graphs. 
- If a subgraph is similar to an attack graph with a certain threshold the model will raise an alarm.  
- We used parallelization in computing Graph edit distance between graph pairs since it consumes a lot of time to process sequentially and it’s a fully independent task. 

In [3]:
import networkx as nx 
from networkx.readwrite import json_graph
import json
import re
from statistics import mean
import random
import time
import matplotlib.pyplot as plt
from datetime import datetime
import dgl 
from sklearn import preprocessing
from nltk.tokenize import word_tokenize
import numpy as np
import pickle
import glob
import argparse
import os
from multiprocessing import Process
import torch 
from torch_geometric.data import InMemoryDataset  
from torch_geometric.data.collate import collate
import torch.nn.functional as F
from torch_geometric.data import Data

from src.ged import graph_edit_distance
import dask.bag as db

def read_json_graph(filename):
    with open(filename) as f:
        js_graph = json.load(f)
    return json_graph.node_link_graph(js_graph)
def ensure_dir(file_path):
    directory = os.path.dirname(file_path)
    if not os.path.exists(directory):
        os.makedirs(directory)

random.seed(123)

# Data Exploration

### Loading Provenance Graphs & malicious entities (IOCs)
- IOCs: stands for Indicators of Compromise

In [18]:
start_running_time = time.time()
training_provenance_graphs = {}
for graph_path in glob.glob('./dataset/raw_logs/provenanceGraph_training*'):
    graph_name = graph_path.split("_")[-1].replace(".json","")
    training_provenance_graphs[graph_name] = read_json_graph(graph_path)
    print("Training Provenance Graph: ", graph_name)
testing_provenance_graphs = {}
for graph_path in glob.glob('./dataset/raw_logs/provenanceGraph_testing*'):
    graph_name = graph_path.split("_")[-1].replace(".json","")
    testing_provenance_graphs[graph_name] = read_json_graph(graph_path)
    print("Testing Provenance Graph: ", graph_name)
atlas_ioc_f = "./dataset/raw_logs/atlas_ioc.json"
with open(atlas_ioc_f) as f:
    atlas_ioc = json.load(f)
iocs = [i for sublist in atlas_ioc.values() for i in sublist]

Training Provenance Graph:  S1
Training Provenance Graph:  S2
Training Provenance Graph:  S3
Testing Provenance Graph:  S4


### Exploring one Provenance Graph

In [25]:
def explore_graph(g):
    print("Number of nodes: ", g.number_of_nodes())
    print("Number of edges: ", g.number_of_edges())
    x  = list(g.nodes.data("type"))
    unique_nodes_types = list(set([y[1] for y in x]))
    print("\nUnique nodes type:",unique_nodes_types)
    for i in unique_nodes_types:
        print(i,": ", len([node_id for node_id, node_type in g.nodes.data("type") if node_type == i]) )
    x  = list(g.edges.data("type"))
    unique_edges_types = list(set([y[2] for y in x]))
    print("\nUnique edges type:",unique_edges_types)
    for i in unique_edges_types:
        print(i,": ", len([node_id for node_id,_, node_type in g.edges.data("type") if node_type == i]) )

In [26]:
explore_graph(training_provenance_graphs["S1"].copy())

Number of nodes:  7459
Number of edges:  14800

Unique nodes type: ['IP_Address', 'domain_name', 'connection', 'session', 'web_object', 'process', 'file']
IP_Address :  121
domain_name :  92
connection :  98
session :  1220
web_object :  550
process :  210
file :  5168

Unique edges type: ['fork', 'execute', 'executed', 'bind', 'write', 'read', 'sock_send', 'resolve', 'connect', 'connected_session', 'connected_remote_ip', 'delete', 'web_request', 'refer']
fork :  137
execute :  227
executed :  195
bind :  1138
write :  829
read :  6935
sock_send :  1188
resolve :  168
connect :  1203
connected_session :  1122
connected_remote_ip :  171
delete :  842
web_request :  551
refer :  94


# Data Preperation

### Label Suspicious nodes in all graphs based on IOCs.

In [36]:
def label_susp_nodes(processGraph,iocs): 
    for node_id, node_type in list(processGraph.nodes.data("type")):
        processGraph.nodes[node_id]["_display"] = str(node_id)  
    matched_ioc = {}
    for node_id, node_type in list(processGraph.nodes.data("type")):
        processGraph.nodes[node_id]["suspicous"] = False
        ioc = ''
        for ioc in iocs:
            if ioc in node_id.lower():
                processGraph.nodes[node_id]["suspicous"] = True 
                if ioc in matched_ioc: 
                        matched_ioc[ioc].append(node_id)
                else:
                    matched_ioc[ioc] = []
                    matched_ioc[ioc].append(node_id)
    count_matched_ioc = {}
    for n in matched_ioc:
        count_matched_ioc[n] = len(matched_ioc[n])
    matchedNodes = set([item for sublist in matched_ioc.values() for item in sublist])
    processNodes = set([node_id for node_id, node_type in processGraph.nodes.data("type") if node_type == 'process'])
    print("\nTotal number of matched nodes:", len(matchedNodes))
    print(count_matched_ioc) 
    return processGraph, matched_ioc , matchedNodes , processNodes

In [41]:
for g in training_provenance_graphs:
    print("\nlabel suspicious nodes for:", g)
    training_provenance_graphs[g],matched_ioc,matchedNodes,processNodes = label_susp_nodes(training_provenance_graphs[g],iocs)
for g in testing_provenance_graphs:
    print("\nlabel suspicious nodes for:", g)
    testing_provenance_graphs[g],matched_ioc,matchedNodes,processNodes = label_susp_nodes(testing_provenance_graphs[g],iocs)


label suspicious nodes for: S1

Total number of matched nodes: 25
{'0xalsaheel.com': 10, '192.168.223.3': 10, 'payload.exe': 5}

label suspicious nodes for: S2

Total number of matched nodes: 29
{'192.168.223.3': 13, '0xalsaheel.com': 10, 'aalsahee/index.html': 1, 'msf.doc': 1, 'msf.rtf': 2, 'payload.exe': 2}

label suspicious nodes for: S3

Total number of matched nodes: 27
{'192.168.223.3': 11, 'msf.rtf': 6, '0xalsaheel.com': 3, 'msf.exe': 6, 'payload.exe': 5, 'aalsahee/index.html': 1}

label suspicious nodes for: S4

Total number of matched nodes: 26
{'0xalsaheel.com': 3, '192.168.223.3': 11, 'msf.doc': 7, 'msf.rtf': 1, 'aalsahee/index.html': 1, 'payload.exe': 4, 'pypayload.exe': 4}


### Extract Suspicious Subgraphs from each provenane graph.
- Traversing the provenance graph forward & backword starting from suspicious nodes.
- Using AdaptiveBFS() algorithm to traverse.
- "depth": determine number of hubs in traversing the graph

In [48]:
# extract suspicious subgraph from provenance graph 
def extract_susp_prov_graph(networkx_graph,matched_ioc,matchedNodes,processNodes,depth = 4):
    start_time = time.time()
    seeds = [(k,n)  for k,n in sorted(matched_ioc.items(), key=lambda item: len(item[1]))]
    seed = seeds[0][0]
    covered = set()
    covered.add(seed)
    susp = nx.MultiDiGraph()
    suspGraphs = []

    #Traverse Forward & Backward 
    def AdaptiveBFS(root,depth = None):
        level = 0
        visited = set()
        currentLevel = [root]
        while currentLevel:
            nextLevel = set()
            for node in currentLevel:
                for _, nEdge in networkx_graph.out_edges(node):
                    if (nEdge not in visited) and (nEdge in matchedNodes or  nEdge in processNodes) :
                        edge_attr = networkx_graph.get_edge_data(node, nEdge).keys()
                        for key in edge_attr:
                            yield node , nEdge , key
                        nextLevel.add(nEdge)
                for pEdge,_ in networkx_graph.in_edges(node):
                    if (pEdge not in visited) and (pEdge in matchedNodes or  pEdge in processNodes) :
                        edge_attr = networkx_graph.get_edge_data( pEdge, node).keys()
                        for key in edge_attr:
                            yield pEdge, node , key
                        nextLevel.add(pEdge)
                visited.add(node)
            if depth:
                if (depth-level) > 0:
                    level += 1
                elif (depth-level) == 0:
                    break
                else:
                    break
            else:
                currentLevel = nextLevel

    #AdaptiveBFSS return one traversed subgraph
    #susp contain the aggregation of subgraphs, it start with empty graphs, stops when it covers all IoCs 
    def ExpandSearch(seedNodes,susp,depth = None):
        x = 0
        for node in seedNodes:
            x += 1
            startNode = node
            travNodes = []
            travNodes = AdaptiveBFS(startNode,depth)
            subgraphEdges = []
            for edge in iter(travNodes):
                subgraphEdges.append(edge)
            subgraph = networkx_graph.edge_subgraph(subgraphEdges).copy()
            subgraphEdges = None
            susp = nx.compose(susp,subgraph)
            subgraph = None
            for ioc , nodes in seeds:
                if ioc not in covered:
                    for node in nodes:
                        if susp.has_node(node):
                            covered.add(ioc)
                            continue
            remain_nodes = [ (ioc,nodes) for ioc,nodes in seeds if ioc not in covered ]  
            if not remain_nodes:
                suspGraphs.append(susp)
            else:
                covered.add(remain_nodes[0][0])
                # print("next remain node: ", remain_nodes[0][0])
                ExpandSearch(remain_nodes[0][1],susp,depth)
        susp = None
        return suspGraphs  
    
    suspGraphs = ExpandSearch(matched_ioc[seed],susp,depth)
    print("first seed:", seed)
    print("Number of subgraphs:", len(suspGraphs))
    print("Average number of nodes in subgraphs:",round(mean([supgraph.number_of_nodes() for supgraph in suspGraphs])))
    print("--- %s seconds ---" % (time.time() - start_time))
    return suspGraphs

In [73]:
suspSubGraphs_training = {}
suspSubGraphs_testing = {}

for g in training_provenance_graphs:
    print("\nExtract Suspicious subgraphs from: ", g)
    suspSubGraphs_training[g] = extract_susp_prov_graph(training_provenance_graphs[g].copy(),matched_ioc,matchedNodes,processNodes)

for g in testing_provenance_graphs:
    print("\nExtract Suspicious subgraphs from:", g)
    suspSubGraphs_testing[g] = extract_susp_prov_graph(testing_provenance_graphs[g].copy(),matched_ioc,matchedNodes,processNodes)


Extract Suspicious subgraphs from:  S1
first seed: msf.rtf
Number of subgraphs: 15
Average number of nodes in subgraphs: 3
--- 0.003475189208984375 seconds ---

Extract Suspicious subgraphs from:  S2
first seed: msf.rtf
Number of subgraphs: 15
Average number of nodes in subgraphs: 3
--- 0.003242969512939453 seconds ---

Extract Suspicious subgraphs from:  S3
first seed: msf.rtf
Number of subgraphs: 15
Average number of nodes in subgraphs: 3
--- 0.0036084651947021484 seconds ---

Extract Suspicious subgraphs from: S4
first seed: msf.rtf
Number of subgraphs: 6
Average number of nodes in subgraphs: 13
--- 0.18305253982543945 seconds ---


### Extract random subgraphs to generate training/testing sets.
- depth Default is to traverse 4 hubs.
- n_subgraph: number of subgraphs to be extracted.
- "depth": number of hubs in traversing the graph.
- "min_nodes": minimum number of nodes in extracted subgraph.
- "max_nodes": maximum number of nodes in extracted subgraph.
    


In [59]:
def extract_benign_graph(processGraph,n_subgraphs,min_nodes,max_nodes,depth = 4):
    start_time = time.time()
    benignSubGraphs = []
    #Traverse Forward & Backward 
    def BFS(root,depth = None):
        level = 0
        visited = set()
        currentLevel = [root]
        while currentLevel:
            nextLevel = set()
            for node in currentLevel:
                for _, nEdge in processGraph.out_edges(node):
                    if nEdge not in visited:
                        edge_attr = processGraph.get_edge_data(node, nEdge).keys()
                        for key in edge_attr:
                            yield node , nEdge , key
                        nextLevel.add(nEdge)
                for pEdge,_ in processGraph.in_edges(node):
                    if pEdge not in visited:
                        edge_attr = processGraph.get_edge_data( pEdge, node).keys()
                        for key in edge_attr:
                            yield pEdge, node , key
                        nextLevel.add(pEdge)
                visited.add(node)
            if depth:
                if (depth - level) > 0:
                    level += 1
                elif (depth-level) == 0:
                    break
                else:
                    break
            else:
                currentLevel = nextLevel
    
    benign_nodes = list(set([node_id for node_id, is_suspicious in processGraph.nodes.data("suspicous") if is_suspicious == False]))
    random.shuffle(benign_nodes)
    for seed in benign_nodes:
        travNodes = BFS(seed,depth)
        subgraphEdges = []
        for edge in iter(travNodes):
            subgraphEdges.append(edge)
        subgraph = processGraph.edge_subgraph(subgraphEdges).copy()
        if max_nodes >= subgraph.number_of_nodes() >= min_nodes:
            benignSubGraphs.append(subgraph)
        if len(benignSubGraphs) >= n_subgraphs:
            break
    print("Number of benign subgraphs:", len(benignSubGraphs))
    print("Average number of nodes in benign subgraphs:",round(mean([supgraph.number_of_nodes() for supgraph in benignSubGraphs])))
    print("Max number of nodes in benign subgraphs:",max([supgraph.number_of_nodes() for supgraph in benignSubGraphs]))
    print("Min number of nodes in benign subgraphs:",min([supgraph.number_of_nodes() for supgraph in benignSubGraphs]))
    print("--- %s seconds ---" % (time.time() - start_time))
    processGraph = None
    return benignSubGraphs

In [74]:
benignSubGraphs_training = {}
benignSubGraphs_testing = {}
n_subgraphs = 250
min_nodes = 5
max_nodes = 40
for g in training_provenance_graphs:
    print("\nExtract Benign Subgraphs from:", g)
    benignSubGraphs_training[g] = extract_benign_graph(training_provenance_graphs[g].copy(),n_subgraphs,min_nodes,max_nodes)
for g in testing_provenance_graphs:
    print("\nExtract Benign Subgraphs from", g)
    benignSubGraphs_testing[g] = extract_benign_graph(testing_provenance_graphs[g].copy(),n_subgraphs,min_nodes,max_nodes)



Extract Benign Subgraphs from: S1
Number of benign subgraphs: 250
Average number of nodes in benign subgraphs: 11
Max number of nodes in benign subgraphs: 39
Min number of nodes in benign subgraphs: 5
--- 1.0541822910308838 seconds ---

Extract Benign Subgraphs from: S2
Number of benign subgraphs: 250
Average number of nodes in benign subgraphs: 8
Max number of nodes in benign subgraphs: 38
Min number of nodes in benign subgraphs: 5
--- 0.8412396907806396 seconds ---

Extract Benign Subgraphs from: S3
Number of benign subgraphs: 250
Average number of nodes in benign subgraphs: 11
Max number of nodes in benign subgraphs: 38
Min number of nodes in benign subgraphs: 5
--- 0.4946446418762207 seconds ---

Extract Benign Subgraphs from S4
Number of benign subgraphs: 250
Average number of nodes in benign subgraphs: 9
Max number of nodes in benign subgraphs: 40
Min number of nodes in benign subgraphs: 5
--- 1.7599563598632812 seconds ---


### Preproces subgraphs to prepare training/testing datasets
- One Hot-encoding nodes types
- Keep nodes types only to be suitable for the embedding algorithm GCN 
- Convert Networkx graph to torch dataset to easily feed the pytorch model


In [64]:
def preprocess_graph_Hot_Encoding(g):
    types = ["process","file", "ip_address" , "web_object", "connection", "session","domain_name"]
    mapping = {name: j for j, name in enumerate(g.nodes())}
    g = nx.relabel_nodes(g, mapping)
    x = torch.zeros(g.number_of_nodes(), dtype=torch.long)
    for node, info in g.nodes(data=True):
        x[int(node)] = types.index(info['type'].lower())
    x = F.one_hot(x, num_classes=len(types)).to(torch.float)
    for node in g.nodes():
        g.nodes[node]["label"] = x[node]
    dgl_graph = dgl.from_networkx(g,node_attrs=["label"])   
    g,x = None,None
    return dgl_graph

In [75]:
training_dataset = []
for g in training_provenance_graphs:
    for i in range(len(suspSubGraphs_training[g])):
        suspSubGraphs_training[g][i] = preprocess_graph_Hot_Encoding(suspSubGraphs_training[g][i])
    for i in range(len(benignSubGraphs_training[g])):    
        benignSubGraphs_training[g][i]= preprocess_graph_Hot_Encoding(benignSubGraphs_training[g][i])
    training_dataset = training_dataset + suspSubGraphs_training[g] +  benignSubGraphs_training[g]
    print("Encoded:", g)
testing_dataset = []
for g in testing_provenance_graphs:
    for i in range(len(suspSubGraphs_testing[g])):
        suspSubGraphs_testing[g][i] = preprocess_graph_Hot_Encoding(suspSubGraphs_testing[g][i])
    for i in range(len(benignSubGraphs_testing[g])):    
        benignSubGraphs_testing[g][i]= preprocess_graph_Hot_Encoding(benignSubGraphs_testing[g][i])
    print("Encoded:", g)
    testing_dataset = testing_dataset + suspSubGraphs_testing[g] +  benignSubGraphs_testing[g]


Encoded: S1
Encoded: S2
Encoded: S3
Encoded: S4


In [None]:
def convert_to_torch_data(training_graphs,testing_graphs):
    training_data_list = []
    testing_data_list = []
    ids = 0 
    for g in training_graphs:
        edge_index = torch.tensor([g.edges()[0].tolist(),g.edges()[1].tolist()])
        data = Data(edge_index= edge_index, i= ids)
        data.num_nodes = g.number_of_nodes()
        data.x = g.ndata['label']
        training_data_list.append(data)
        ids += 1
        print("done", ids)
    for g in testing_graphs:
        edge_index = torch.tensor([g.edges()[0].tolist(),g.edges()[1].tolist()])
        data = Data(edge_index=edge_index, i= ids)
        data.num_nodes = g.number_of_nodes()
        data.x = g.ndata['label']
        testing_data_list.append(data)
        ids += 1
        print("done", ids)

In [None]:
torch_training_set,torch_testing_set = convert_to_torch_data(training_dataset,testing_dataset)

# Experiment Setup
- For each host, we extract all suspicious subgraphs based on ground truth nodes, and 250 benign subgraphs randomly extracted from the provenance graph  
- Used 795 subgraphs for training and 256 subgraphs for testing. 
- Prepared dataset consists of 1000 benign subgraph and 39 suspicious subgraphs
    - That shouldn’t affect the similarity model accuracy since it predict similarity using GED as a target variable. It doesn’t predict maliciousness of a subgraph.
- The SimGNN model is trained on comination of pairs of training set.
    - 632025 graph pairs
- The SimGNN model is tested with comination of pairs of testing set.
    - 203520 graph pairs

In [78]:
print("Total number of training Graphs",len(training_dataset))
print("Total number of testing Graphs",len(testing_dataset))

Total number of training Graphs 795
Total number of testing Graphs 256


# Compute GED
- GED: stands for Graph Edit Distance.
- SimGNN model uses GED as a target label.
- We used a minimum distance of three algorithms to compute approximate GED as explained in SimGNN paper.
- Computing GED to prepare trainig/testing sets' labels is the most time-consuming task. Therefore we used DASK to parallelize the computation.  
- We will show computing GED with and without DASK for few samples compare time performance.
    - we will use only 5 training and 2 testing graph for demonstration

In [20]:
from src.ged import graph_edit_distance

In [84]:
graph_data = training_dataset[:5] + testing_dataset[:2]
n_training = len(training_dataset)
n_dataset = len(graph_data)
ged_matrix = torch.full((len(graph_data), len(graph_data)), float('inf'))


## Experiments to improve time performance

#### 1 - Using a matrix to store GED scores
- We started by storing each graph pair with it's GED score in disk which involve many data transfer.
- Then we change the approach to store only the graphs and store GED between all pairs in one matrix, that approach significantly improve the time performance. 

#### 2 - Using DASK Bag
- split per each pair of graphs 

In [88]:
def combination_m(n_dataset,n_training):
    combined_list=[]
    for i in range(n_training):
        for j in range(i,n_training):
            combined_list.append((i,j))       
    for i in range(n_training,n_dataset):    
        for j in range(n_training):
            combined_list.append((i,j))
    return combined_list

In [87]:
def ged_distance_dask(i,j):
    distance_beam,_,_ = graph_edit_distance(graph_data[i], graph_data[j], algorithm='beam', max_beam_size=2)
    distance_bipartite, _, _ = graph_edit_distance(graph_data[i], graph_data[j], algorithm='bipartite')
    distance_hausdorff, _, _ = graph_edit_distance(graph_data[i], graph_data[j], algorithm='hausdorff')
    distance = min(distance_beam,distance_bipartite,distance_hausdorff)
    print(i,j)
    return i,j,distance

In [94]:
start_time = time.time()
combined_list = combination_m(n_dataset,n_training)
graph_dask = db.from_sequence(combined_list, npartitions=10)
graph_GEDs = graph_dask.map(lambda x : ged_distance_dask(x[0],x[1])).compute()
for i,j,d in graph_GEDs:
    ged_matrix[i,j] = d
    if i < n_training:
        ged_matrix[j,i] = d
print("Done computing in : %s seconds ---" % (time.time() - start_time))

#### 3 - Improve Using DASK Bag 
- Split per one graph with all it's pairs
- Avoid computing pairs twice
    - if GED between (a,b) = x, then GED between (b,a) = x 

In [None]:
def ged_distance_dask(i):
    start_time = time.time()
    g1 = graph_data[i]
    geds_sample = []
    if i < n_training:
        for j in range(i, n_training):
            g2 = graph_data[j]
            distance_beam, _, _ = graph_edit_distance(g1, g2, algorithm='beam', max_beam_size=2)
            distance_bipartite, _, _ = graph_edit_distance(g1, g2, algorithm='bipartite')
            distance_hausdorff, _, _ = graph_edit_distance(g1, g2, algorithm='hausdorff')
            distance = min(distance_beam, distance_bipartite, distance_hausdorff)
            geds_sample.append((i, j, distance))
    else:
        for j in range(n_training):
            g2 = graph_data[j]
            distance_beam, _, _ = graph_edit_distance(g1, g2, algorithm='beam', max_beam_size=2)
            distance_bipartite, _, _ = graph_edit_distance(g1, g2, algorithm='bipartite')
            distance_hausdorff, _, _ = graph_edit_distance(g1, g2, algorithm='hausdorff')
            distance = min(distance_beam, distance_bipartite, distance_hausdorff)
            geds_sample.append((i, j, distance))
    ged_matrix_temp,g1,g2 = None,None,None
    return geds_sample

In [None]:
start_time = time.time()
list_indices = list(range(len(graph_data)))
graph_dask = db.from_sequence(list_indices, npartitions=10)
graph_GEDs = graph_dask.map(lambda x: ged_distance_dask(x)).compute()
graph_GEDs = [sample for geds_sample in graph_GEDs for sample in geds_sample]
for i, j, d in graph_GEDs:
    ged_matrix[i, j] = d
    if i < n_training:
        ged_matrix[j, i] = d
print("Done computing in : %s seconds ---" % (time.time() - start_time))

# Model Training 

### Hints:
- We will continue the demonstration with GED matrix computed for the whole dataset.
- We will only call by running commands the SimGNN model from here to show performance. 
    - We will not go through model architecture details
- We used 100 epochs for demonstration, in the complete setting we are using 10000 epochs
    - Models for complete setting saved and provided under "./model" folder.
    - Plots for the complete stting are provided with name "ATLAS_gcn10000.pdf" and "ATLAS_gcn_hist10000.pdf".

In [128]:
! python src/SimGNN/main.py --dataset ATLAS  --epochs 100  --histogram

+---------------------+------------+
|      Parameter      |   Value    |
| Batch size          | 128        |
+---------------------+------------+
| Bins                | 16         |
+---------------------+------------+
| Bottle neck neurons | 16         |
+---------------------+------------+
| Dataset             | ATLAS      |
+---------------------+------------+
| Dataset path        | ./dataset/ |
+---------------------+------------+
| Diffpool            | 0          |
+---------------------+------------+
| Dropout             | 0          |
+---------------------+------------+
| Epochs              | 100        |
+---------------------+------------+
| Filters 1           | 64         |
+---------------------+------------+
| Filters 2           | 32         |
+---------------------+------------+
| Filters 3           | 16         |
+---------------------+------------+
| Gnn operator        | gcn        |
+---------------------+------------+
| Histogram           | 1          |
+

Batches:  86%|██████████████████████████████▊     | 6/7 [00:06<00:01,  1.10s/it][A
Batches: 100%|████████████████████████████████████| 7/7 [00:06<00:00,  1.25it/s][A
Epoch (Loss=0.02792):  14%|██▊                 | 14/100 [02:00<11:36,  8.10s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:01<00:06,  1.09s/it][A
Batches:  29%|██████████▎                         | 2/7 [00:02<00:05,  1.05s/it][A
Batches:  43%|███████████████▍                    | 3/7 [00:03<00:04,  1.08s/it][A
Batches:  57%|████████████████████▌               | 4/7 [00:04<00:03,  1.07s/it][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:05<00:02,  1.10s/it][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:06<00:01,  1.13s/it][A
Batches: 100%|████████████████████████████████████| 7/7 [00:06<00:00,  1.22it/s][A
Epoch (Loss=0.02943):  15%|███                 | 15/100 [02:07<10:56,  7.73s

Batches:  86%|██████████████████████████████▊     | 6/7 [00:06<00:01,  1.03s/it][A
Batches: 100%|████████████████████████████████████| 7/7 [00:06<00:00,  1.33it/s][A
Epoch (Loss=0.01996):  31%|██████▏             | 31/100 [04:17<11:43, 10.20s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:01<00:06,  1.14s/it][A
Batches:  29%|██████████▎                         | 2/7 [00:02<00:05,  1.08s/it][A
Batches:  43%|███████████████▍                    | 3/7 [00:03<00:04,  1.09s/it][A
Batches:  57%|████████████████████▌               | 4/7 [00:04<00:03,  1.11s/it][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:05<00:02,  1.08s/it][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:06<00:01,  1.03s/it][A
Batches: 100%|████████████████████████████████████| 7/7 [00:06<00:00,  1.36it/s][A
Epoch (Loss=0.01881):  32%|██████▍             | 32/100 [04:23<10:19,  9.11s

Validation:  10%|██▋                        | 200/2000 [00:01<00:15, 112.72it/s][A[A

Validation:  15%|████                       | 300/2000 [00:01<00:10, 160.95it/s][A[A

Validation:  20%|█████▌                      | 400/2000 [00:03<00:16, 98.85it/s][A[A

Validation:  25%|██████▊                    | 500/2000 [00:03<00:11, 132.18it/s][A[A

Validation:  30%|████████                   | 600/2000 [00:05<00:13, 107.47it/s][A[A

Validation:  35%|█████████▍                 | 700/2000 [00:05<00:09, 137.97it/s][A[A

Validation:  40%|██████████▊                | 800/2000 [00:05<00:07, 166.07it/s][A[A

Validation:  45%|████████████▏              | 900/2000 [00:06<00:05, 186.65it/s][A[A

Validation:  50%|█████████████             | 1000/2000 [00:06<00:05, 198.62it/s][A[A

Validation:  55%|██████████████▎           | 1100/2000 [00:07<00:06, 131.46it/s][A[A

Validation:  60%|███████████████▌          | 1200/2000 [00:08<00:05, 154.39it/s][A[A

Validation:  65%|███████████████

Batches:  14%|█████▏                              | 1/7 [00:01<00:06,  1.03s/it][A
Batches:  29%|██████████▎                         | 2/7 [00:02<00:05,  1.10s/it][A
Batches:  43%|███████████████▍                    | 3/7 [00:03<00:04,  1.06s/it][A
Batches:  57%|████████████████████▌               | 4/7 [00:04<00:03,  1.17s/it][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:05<00:02,  1.12s/it][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:06<00:01,  1.07s/it][A
Batches: 100%|████████████████████████████████████| 7/7 [00:06<00:00,  1.30it/s][A
Epoch (Loss=0.01049):  68%|█████████████▌      | 68/100 [08:59<03:47,  7.10s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:01<00:06,  1.06s/it][A
Batches:  29%|██████████▎                         | 2/7 [00:02<00:05,  1.03s/it][A
Batches:  43%|███████████████▍                    | 3/7 [00:03<00:04,  1.08s

Batches:  14%|█████▏                              | 1/7 [00:01<00:07,  1.17s/it][A
Batches:  29%|██████████▎                         | 2/7 [00:02<00:05,  1.09s/it][A
Batches:  43%|███████████████▍                    | 3/7 [00:03<00:04,  1.07s/it][A
Batches:  57%|████████████████████▌               | 4/7 [00:04<00:03,  1.10s/it][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:05<00:02,  1.09s/it][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:06<00:01,  1.09s/it][A
Batches: 100%|████████████████████████████████████| 7/7 [00:06<00:00,  1.30it/s][A
Epoch (Loss=0.01029):  85%|█████████████████   | 85/100 [11:17<01:53,  7.56s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:01<00:06,  1.08s/it][A
Batches:  29%|██████████▎                         | 2/7 [00:02<00:05,  1.06s/it][A
Batches:  43%|███████████████▍                    | 3/7 [00:03<00:04,  1.09s

## Results of complete setting
- Results are comparable with results presented in SimGNN paper on different datasets.
- SimGNN is state of the art in GNN Graph Similarity Model problem.
![title](img/SimGNN_ATLAS_Evaluation.png)


# Tune parameters to improve the model accuracy 
- Training without histogram feature 

In [129]:
! python src/SimGNN/main.py --dataset ATLAS --plot  --epochs 100  

+---------------------+------------+
|      Parameter      |   Value    |
| Batch size          | 128        |
+---------------------+------------+
| Bins                | 16         |
+---------------------+------------+
| Bottle neck neurons | 16         |
+---------------------+------------+
| Dataset             | ATLAS      |
+---------------------+------------+
| Dataset path        | ./dataset/ |
+---------------------+------------+
| Diffpool            | 0          |
+---------------------+------------+
| Dropout             | 0          |
+---------------------+------------+
| Epochs              | 100        |
+---------------------+------------+
| Filters 1           | 64         |
+---------------------+------------+
| Filters 2           | 32         |
+---------------------+------------+
| Filters 3           | 16         |
+---------------------+------------+
| Gnn operator        | gcn        |
+---------------------+------------+
| Histogram           | 0          |
+

Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:00<00:04,  1.44it/s][A
Batches:  29%|██████████▎                         | 2/7 [00:01<00:03,  1.41it/s][A
Batches:  43%|███████████████▍                    | 3/7 [00:02<00:02,  1.43it/s][A
Batches:  57%|████████████████████▌               | 4/7 [00:02<00:02,  1.44it/s][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:03<00:01,  1.44it/s][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:04<00:00,  1.41it/s][A
Epoch (Loss=0.02233):  16%|███▏                | 16/100 [01:24<06:36,  4.72s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:00<00:03,  1.63it/s][A
Batches:  29%|██████████▎                         | 2/7 [00:01<00:03,  1.45it/s][A
Batches:  43%|███████████████▍                    | 3/7 [00:02<00:02,  1.42i

Batches:  29%|██████████▎                         | 2/7 [00:01<00:03,  1.41it/s][A
Batches:  43%|███████████████▍                    | 3/7 [00:02<00:02,  1.49it/s][A
Batches:  57%|████████████████████▌               | 4/7 [00:02<00:02,  1.46it/s][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:03<00:01,  1.47it/s][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:04<00:00,  1.44it/s][A
Epoch (Loss=0.01484):  34%|██████▊             | 34/100 [02:55<05:29,  4.99s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:00<00:02,  2.20it/s][A
Batches:  29%|██████████▎                         | 2/7 [00:01<00:03,  1.62it/s][A
Batches:  43%|███████████████▍                    | 3/7 [00:01<00:02,  1.53it/s][A
Batches:  57%|████████████████████▌               | 4/7 [00:02<00:01,  1.55it/s][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:03<00:01,  1.52i

Batches:  57%|████████████████████▌               | 4/7 [00:02<00:02,  1.38it/s][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:03<00:01,  1.44it/s][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:04<00:00,  1.44it/s][A
Epoch (Loss=0.01323):  52%|██████████▍         | 52/100 [04:24<04:40,  5.85s/it][A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:00<00:04,  1.34it/s][A
Batches:  29%|██████████▎                         | 2/7 [00:01<00:03,  1.38it/s][A
Batches:  43%|███████████████▍                    | 3/7 [00:02<00:02,  1.41it/s][A
Batches:  57%|████████████████████▌               | 4/7 [00:02<00:02,  1.44it/s][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:03<00:01,  1.44it/s][A
Batches:  86%|██████████████████████████████▊     | 6/7 [00:04<00:00,  1.49it/s][A
Batches: 100%|████████████████████████████████████| 7/7 [00:04<00:00,  1.98i

Validation:  85%|██████████████████████    | 1700/2000 [00:06<00:01, 299.11it/s][A[A

Validation:  90%|███████████████████████▍  | 1800/2000 [00:06<00:00, 305.20it/s][A[A

Validation:  95%|████████████████████████▋ | 1900/2000 [00:06<00:00, 320.56it/s][A[A

Validation: 100%|██████████████████████████| 2000/2000 [00:07<00:00, 314.30it/s][A[A

                                                                                [A[A
Batches:   0%|                                            | 0/7 [00:00<?, ?it/s][A
Batches:  14%|█████▏                              | 1/7 [00:00<00:03,  1.50it/s][A
Batches:  29%|██████████▎                         | 2/7 [00:01<00:03,  1.48it/s][A
Batches:  43%|███████████████▍                    | 3/7 [00:02<00:02,  1.45it/s][A
Batches:  57%|████████████████████▌               | 4/7 [00:02<00:02,  1.42it/s][A
Batches:  71%|█████████████████████████▋          | 5/7 [00:03<00:01,  1.35it/s][A
Batches:  86%|██████████████████████████████▊     | 6/7 [

Validation:  15%|████                       | 300/2000 [00:01<00:08, 202.24it/s][A[A

Validation:  20%|█████▍                     | 400/2000 [00:01<00:06, 231.00it/s][A[A

Validation:  25%|██████▊                    | 500/2000 [00:02<00:06, 239.92it/s][A[A

Validation:  30%|████████                   | 600/2000 [00:02<00:05, 239.33it/s][A[A

Validation:  35%|█████████▍                 | 700/2000 [00:03<00:05, 224.87it/s][A[A

Validation:  40%|██████████▊                | 800/2000 [00:03<00:04, 246.32it/s][A[A

Validation:  45%|████████████▏              | 900/2000 [00:03<00:04, 261.09it/s][A[A

Validation:  50%|█████████████             | 1000/2000 [00:04<00:04, 245.83it/s][A[A

Validation:  55%|██████████████▎           | 1100/2000 [00:04<00:03, 263.00it/s][A[A

Validation:  60%|███████████████▌          | 1200/2000 [00:05<00:03, 246.41it/s][A[A

Validation:  65%|████████████████▉         | 1300/2000 [00:05<00:02, 274.09it/s][A[A

Validation:  70%|███████████████

- Results has improved in complete setting
![title](img/SimGNN_ATLAS_Evaluation_Without_histogram.png)

# Model Evaluation

- Evaluation has been done by pairing two attack graphs with all suspicious subgraphs and 25 benign subgraphs from each host. 

#### With Histogram Features.

In [134]:
! python src/SimGNN/main.py --dataset ATLAS --epochs 10000 --histogram --predict  --load ./model/atlas_simgnn_10000_hist.pt --threshold 0.50

+---------------------+------------------------------------+
|      Parameter      |               Value                |
| Batch size          | 128                                |
+---------------------+------------------------------------+
| Bins                | 16                                 |
+---------------------+------------------------------------+
| Bottle neck neurons | 16                                 |
+---------------------+------------------------------------+
| Dataset             | ATLAS                              |
+---------------------+------------------------------------+
| Dataset path        | ./dataset/                         |
+---------------------+------------------------------------+
| Diffpool            | 0                                  |
+---------------------+------------------------------------+
| Dropout             | 0                                  |
+---------------------+------------------------------------+
| Epochs              | 

#### Without histogram features.

In [135]:
! python src/SimGNN/main.py --dataset ATLAS --epochs 10000  --predict  --load ./model/atlas_simgnn_10000.pt --threshold 0.50

+---------------------+-------------------------------+
|      Parameter      |             Value             |
| Batch size          | 128                           |
+---------------------+-------------------------------+
| Bins                | 16                            |
+---------------------+-------------------------------+
| Bottle neck neurons | 16                            |
+---------------------+-------------------------------+
| Dataset             | ATLAS                         |
+---------------------+-------------------------------+
| Dataset path        | ./dataset/                    |
+---------------------+-------------------------------+
| Diffpool            | 0                             |
+---------------------+-------------------------------+
| Dropout             | 0                             |
+---------------------+-------------------------------+
| Epochs              | 10000                         |
+---------------------+-------------------------