In [34]:
import os
import json
import pickle
import random
import logging
import numpy as np
import pandas as pd
from glob import glob
from tqdm.notebook import tqdm
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report

import dgl
import dgl.nn as dglnn
from dgl.nn import GraphConv, GATConv, SAGEConv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup

os.environ['CUDA_VISIBLE_DEVICES'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 

In [2]:
with open("./relations.txt") as fp:
    relations = [r.strip() for r in fp.readlines()]
with open("./label2index.pkl", "rb") as fp:
    label2index = pickle.load(fp)
with open("./index2label.pkl", "rb") as fp:
    index2label = pickle.load(fp)

# Make dataset

In [3]:
def get_value(event):
    global type2attr

    srcUUID = event["srcNode"]["UUID"]
    srcType = event["srcNode"]["Type"]
    srcAttr = event["srcNode"][type2attr[srcType]]
    dstUUID = event["dstNode"]["UUID"] if event["dstNode"] != None else srcUUID
    dstType = event["dstNode"]["Type"] if event["dstNode"] != None else srcType
    dstAttr = event["dstNode"][type2attr[dstType]] if event["dstNode"] != None else srcAttr
    return srcUUID, srcAttr, dstUUID, dstAttr, event["relation"], event["label"]

In [4]:
def make_dataset(dataset):
    global node_ent2idx, edge_ent2idx, node_ent2emb, edge_ent2emb

    data_list = []
    for p in tqdm(dataset):
        with open(p) as fp:
            events = json.load(fp)

        nodes = set()
        edges = []
        relations = []
        labels = []
        uuid2res = {}
        for e in events:
            srcUUID, srcAttr, dstUUID, dstAttr, rel, label = get_value(e)

            uuid2res[srcUUID], uuid2res[dstUUID] = srcAttr, dstAttr
            nodes.add(srcUUID)
            nodes.add(dstUUID)
            edges.append((srcUUID, dstUUID))
            relations.append(edge_ent2idx[rel])
            labels.append(label2index[label])
        nodes = list(nodes)  
        node_feat = [torch.tensor(node_ent2emb[node_ent2idx[uuid2res[uuid]]], dtype=torch.float32) for uuid in nodes]
        edge_attr = [torch.tensor(edge_ent2emb[idx], dtype=torch.float32) for idx in relations]

        src = [nodes.index(src_uuid) for src_uuid, dst_uuid in edges]
        dst = [nodes.index(dst_uuid) for src_uuid, dst_uuid in edges]
        edge_index = torch.tensor([src, dst], dtype=torch.long)

        
        data_list.append({
            "labels": labels,
            "num_nodes": len(nodes),
            "node_feat": node_feat,
            "edge_attr": edge_attr,
            "edge_index": edge_index
        })
    return data_list         

In [7]:
with open(f'./secureBERT_YR/node_vocab2index.pkl', 'rb') as fp:
    node_ent2idx = pickle.load(fp)
with open(f'./secureBERT_YR/edge_vocab2index.pkl', 'rb') as fp:
    edge_ent2idx = pickle.load(fp)
with open(f'./secureBERT_YR/nodes_ent2emb_256.pkl', 'rb') as fp:
    node_ent2emb = pickle.load(fp)
with open(f'./secureBERT_YR/edges_ent2emb_16.pkl', 'rb') as fp:
    edge_ent2emb = pickle.load(fp)

type2attr = {
    "Process": "Cmdline", 
    "File": "Name", 
    "Registry": "Key", 
    "Network": "Dstaddress"
}

random.seed(42)
trainset, validset, testset = [], [], []
for ability in tqdm(os.listdir('../data/Raw_dataset/')):
    paths = glob(f'../data/Raw_dataset/{ability}/number_*/expanded_instance.json')
    random.shuffle(paths)
    trainset.extend(paths[:80])
    validset.extend(paths[80:90])
    testset.extend(paths[90:])
    
train_data = make_dataset(trainset)
valid_data = make_dataset(validset)
test_data = make_dataset(testset)

  0%|          | 0/167 [00:00<?, ?it/s]

  0%|          | 0/13360 [00:00<?, ?it/s]

  0%|          | 0/1670 [00:00<?, ?it/s]

  0%|          | 0/1670 [00:00<?, ?it/s]

In [35]:
node_ent2idx

{'C:\\Users\\meaganaustin\\AppData\\Local\\Packages\\Microsoft.MSPaint_8wekyb3d8bbwe\\AC\\Temp\\*': 0,
 'C:\\Users\\andersonluis\\AppData\\Local\\Google\\Chrome\\User Data\\RecoveryImproved\\.git': 1,
 'C:\\Users\\xmoore\\Music\\*': 2,
 'C:\\Users\\alisonhaynes\\AppData\\Local\\Packages\\Microsoft.AccountsControl_cw5n1h2txyewy\\AC\\.git': 3,
 'C:\\Users\\perezrodney\\AppData\\Local\\Google\\Chrome\\User Data\\Default\\Extensions\\ghbmnnjooekpmoecnnnilnnbdlolhkhi\\1.41.0_0\\_locales\\cy\\.git': 4,
 'C:\\Users\\heatherhoward\\AppData\\Local\\Packages\\Microsoft.NET.Native.Framework.1.7_8wekyb3d8bbwe\\AC\\INetCookies': 5,
 'C:\\Users\\ezk\\Anaconda3\\Lib\\site-packages\\distributed\\widgets\\templates\\task_state.html.j2': 6,
 'C:\\Users\\ezk\\Anaconda3\\Lib\\site-packages\\sphinxcontrib\\htmlhelp\\locales\\eo\\LC_MESSAGES': 7,
 'C:\\Users\\steven78\\AppData\\Local\\Packages\\Microsoft.Windows.Search_cw5n1h2txyewy\\AC\\Microsoft\\Internet Explorer\\DOMStore\\J15GVKBA': 8,
 'C:\\Users\\eli

In [28]:
edge_ent2idx

{'Process Create': 0,
 'Process Start': 1,
 'CreateFile': 2,
 'SetBasicInformationFile': 3,
 'SetDispositionInformationEx': 4,
 'SetDispositionInformationFile': 5,
 'WriteFile': 6,
 'TCP Connect': 7,
 'TCP Send': 8,
 'UDP Send': 9,
 'TCP Disconnect': 10,
 'RegQueryKey': 11,
 'RegQueryValue': 12,
 'CloseFile': 13,
 'QueryAllInformationFile': 14,
 'QueryAttributeTagFile': 15,
 'QueryBasicInformationFile': 16,
 'QueryDirectory': 17,
 'QueryNetworkOpenInformationFile': 18,
 'ReadFile': 19,
 'TCP Receive': 20,
 'UDP Receive': 21,
 'RegCreateKey': 22,
 'RegSetValue': 23,
 'RegCloseKey': 24,
 'RegDeleteValue': 25,
 'RegOpenKey': 26}

In [39]:
node_ent2emb.shape

(1609759, 256)

In [40]:
edge_ent2emb.shape

(27, 16)

In [30]:
edge_ent2emb

array([[-3.47480446e-01,  1.73320808e-02,  8.85579437e-02,
         1.62743941e-01,  2.84419924e-01,  1.91913456e-01,
         2.93871085e-03,  5.25181778e-02, -4.25176099e-02,
         3.54782678e-04,  3.57708707e-02, -1.68276317e-02,
        -1.07634822e-02, -1.11624133e-02,  2.85948627e-02,
        -3.56004909e-02],
       [-3.60511810e-01,  1.88438538e-02,  9.39114019e-02,
         2.14220136e-01,  2.96153992e-01,  1.66956887e-01,
        -1.00340536e-02,  1.75624080e-02, -6.43368959e-02,
         1.50027927e-02,  4.54650857e-02, -4.25869832e-03,
        -2.36397795e-02, -1.11015961e-02, -5.89860976e-02,
         4.06207405e-02],
       [-6.41673267e-01,  9.66272205e-02,  2.94817299e-01,
         1.96637318e-01, -1.68469816e-01, -1.56227037e-01,
         2.53720731e-01,  2.25251481e-01,  5.19338548e-02,
         4.25831508e-03, -2.68688612e-02,  2.12919395e-02,
        -3.01366057e-02, -6.21061586e-03, -9.95471515e-03,
         1.99333467e-02],
       [ 4.90801871e-01, -1.58524066e

In [31]:
relations

['Process Create',
 'Process Start',
 'CreateFile',
 'SetBasicInformationFile',
 'SetDispositionInformationEx',
 'SetDispositionInformationFile',
 'WriteFile',
 'TCP Connect',
 'TCP Send',
 'UDP Send',
 'TCP Disconnect',
 'RegQueryKey',
 'RegQueryValue',
 'CloseFile',
 'QueryAllInformationFile',
 'QueryAttributeTagFile',
 'QueryBasicInformationFile',
 'QueryDirectory',
 'QueryNetworkOpenInformationFile',
 'ReadFile',
 'TCP Receive',
 'UDP Receive',
 'RegCreateKey',
 'RegSetValue',
 'RegCloseKey',
 'RegDeleteValue',
 'RegOpenKey']

In [32]:
label2index

{'T1105_1095434782a00c8a4772a11e625bcf5d_B': 0,
 'T1105_e6715e61f5df646692c624b3499384c4_I': 1,
 'T1003.003_9f73269695e54311dd61dc68940fb3e1_B': 2,
 'T1074.001_6469befa-748a-4b9c-a96d-f191fde47d89_I': 3,
 'T1112_ba6f6214dbd17c54001e0a163b60f151_I': 4,
 'T1059.003_f38e58deb7ad20b5538ca40db7b7b4f8_B': 5,
 'T1112_cab7b85611a290c0769546bfa9d6f962_I': 6,
 'T1036.003_f5ef8466e5ebcd2ae03f338d9416069c_I': 7,
 'T1059.001_55678719-e76e-4df9-92aa-10655bbd1cf4_I': 8,
 'T1087.001_6334877e8e3ba48f7835d4856d90a282_B': 9,
 'T1016_e8017c46-acb8-400c-a4b5-b3362b5b5baa_B': 10,
 'T1574.011_72249c1e9ffe7d8f30243d838e0791ca_I': 11,
 'T1074.001_4e97e699-93d7-4040-b5a3-2e906a58199e_I': 12,
 'T1105_eb814e03-811a-467a-bc6d-dcd453750fa2_B': 13,
 'T1087.001_6334877e8e3ba48f7835d4856d90a282_I': 14,
 'T1036.004_7de3d7b4922a7b996d8df36fb22bb118_B': 15,
 'T1003.003_f049b89533298c2d6cd37a940248b219_B': 16,
 'T1547.001_163b023f43aba758d36f524d146cb8ea_B': 17,
 'T1070.005_1f91076e2be2014cc7b4f1296de02fd6_B': 18,
 'T1564

In [33]:
index2label

{0: 'T1105_1095434782a00c8a4772a11e625bcf5d_B',
 1: 'T1105_e6715e61f5df646692c624b3499384c4_I',
 2: 'T1003.003_9f73269695e54311dd61dc68940fb3e1_B',
 3: 'T1074.001_6469befa-748a-4b9c-a96d-f191fde47d89_I',
 4: 'T1112_ba6f6214dbd17c54001e0a163b60f151_I',
 5: 'T1059.003_f38e58deb7ad20b5538ca40db7b7b4f8_B',
 6: 'T1112_cab7b85611a290c0769546bfa9d6f962_I',
 7: 'T1036.003_f5ef8466e5ebcd2ae03f338d9416069c_I',
 8: 'T1059.001_55678719-e76e-4df9-92aa-10655bbd1cf4_I',
 9: 'T1087.001_6334877e8e3ba48f7835d4856d90a282_B',
 10: 'T1016_e8017c46-acb8-400c-a4b5-b3362b5b5baa_B',
 11: 'T1574.011_72249c1e9ffe7d8f30243d838e0791ca_I',
 12: 'T1074.001_4e97e699-93d7-4040-b5a3-2e906a58199e_I',
 13: 'T1105_eb814e03-811a-467a-bc6d-dcd453750fa2_B',
 14: 'T1087.001_6334877e8e3ba48f7835d4856d90a282_I',
 15: 'T1036.004_7de3d7b4922a7b996d8df36fb22bb118_B',
 16: 'T1003.003_f049b89533298c2d6cd37a940248b219_B',
 17: 'T1547.001_163b023f43aba758d36f524d146cb8ea_B',
 18: 'T1070.005_1f91076e2be2014cc7b4f1296de02fd6_B',
 19: 'T

# Make Torch dataset

In [17]:
class GraphDataset(Dataset):
    def __init__(self, data_list, device):
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data = self.data_list[idx]
        return data

def collate(samples):
    data_list = samples
    batched_graphs = []
    for data in data_list:
        g = dgl.graph((data["edge_index"][0], data["edge_index"][1]), num_nodes=data["num_nodes"])

        g.ndata['feat'] = torch.stack(data["node_feat"])
        g.edata['feat'] = torch.stack(data["edge_attr"])
        g.edata['label'] = torch.tensor(data["labels"])  

        batched_graphs.append(g)
    
    return dgl.batch(batched_graphs)

In [18]:
train_GraphDataset = GraphDataset(train_data, device)
valid_GraphDataset = GraphDataset(valid_data, device)
test_GraphDataset = GraphDataset(test_data, device)

train_dataloader = DataLoader(train_GraphDataset, batch_size=32, shuffle=True, collate_fn=collate)
valid_dataloader = DataLoader(valid_GraphDataset, batch_size=32, shuffle=True, collate_fn=collate)
test_dataloader = DataLoader(test_GraphDataset, batch_size=32, shuffle=False, collate_fn=collate)

# Model

In [19]:
class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(GraphSAGE, self).__init__()
        self.layer1 = dglnn.SAGEConv(in_dim, hidden_dim, 'pool')
        self.layer2 = dglnn.SAGEConv(hidden_dim, out_dim, 'pool')
        self.dropout = nn.Dropout(0.25)

    def forward(self, g, inputs):
        h = self.layer1(g, inputs)
        h = torch.relu(h)
        h = self.dropout(h)
        h = self.layer2(g, h)
        return h
    
class MLPPredictor(nn.Module):
    def __init__(self, out_feats, out_classes, edge_embedding_dim):
        super().__init__()
        self.W = nn.Linear(out_feats*2 + edge_embedding_dim, out_classes)

    def apply_edges(self, edges, edge_feat):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        h_e = edge_feat
        score = self.W(torch.cat([h_u, h_v, h_e], 1))
        return {'score': score}

    def forward(self, graph, h, edge_feat):
        with graph.local_scope():
            graph.ndata['h'] = h
            # graph.apply_edges(self.apply_edges)
            graph.apply_edges(lambda edges: self.apply_edges(edges, edge_feat))
            return graph.edata['score']
        
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_classes, edge_embedding_dim):
        super().__init__()
        self.sage = GraphSAGE(in_features, hidden_features, out_features)
        self.pred = MLPPredictor(out_features, num_classes, edge_embedding_dim)
      
    def forward(self, g, node_feat, edge_feat, return_logits=False):
        h = self.sage(g, node_feat)
        logits = self.pred(g, h, edge_feat)
        
        return logits

In [20]:
def same_seeds(seed = 42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def model_fn(batched_g, model, criterion, device, which_type='train'):
    """Forward a batch through the model."""
    batched_g = batched_g.to(device)
    
    labels = batched_g.edata['label'].to(device)    
    # logits = model(batched_g, batched_g.ndata['feat'].float())
    logits = model(batched_g, batched_g.ndata['feat'].float(), batched_g.edata['feat'].float())
    loss = criterion(logits, labels)

    output = torch.softmax(logits, dim=1)
    preds = output.argmax(1)
    
    accuracy = torch.mean((preds == labels).float())
        
    return loss, accuracy, preds

In [21]:
same_seeds(42)
model = Model(in_features=256, hidden_features=64, out_features=128, num_classes=len(label2index)-1, edge_embedding_dim = 16)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=36, eta_min=0, last_epoch=- 1, verbose=False)
criterion = nn.CrossEntropyLoss()
model_save_path = "./model/GraphSAGE_emb256"
if not os.path.isdir(model_save_path):
    os.makedirs(model_save_path)

epochs = 50
best_val_loss = float('inf')
best_val_acc = float('-inf')
best_model_path = ""
for epoch in tqdm(range(epochs)):
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0    
    for data in train_dataloader:
        loss, accuracy, _ = model_fn(data, model, criterion, device, which_type='train')        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()

    avg_loss = total_loss / len(train_dataloader)
    avg_accuracy = total_accuracy / len(train_dataloader)
    logging.info(f'Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_accuracy:.4f}')
    
    # Validation Part
    model.eval()
    total_accuracy = 0.0
    total_loss = 0.0
    with torch.no_grad():
        for data in valid_dataloader:
            loss, accuracy, _ = model_fn(data, model, criterion, device, which_type='validation')
            total_accuracy += accuracy.item()
            total_loss += loss.item()

    avg_accuracy = total_accuracy / len(valid_dataloader)
    current_loss = total_loss / len(valid_dataloader)
    if current_loss < best_val_loss and avg_accuracy > best_val_acc:
        best_val_loss = current_loss
        best_val_acc = avg_accuracy
        best_model_path = f'{model_save_path}/epoch_{epoch}_loss_{current_loss:.4f}_acc_{avg_accuracy:.4f}'
    
    logging.info(f'Validation Loss: {current_loss:.4f} | Validation Accuracy: {avg_accuracy:.4f}\n')    
    torch.save(model.state_dict(), f'{model_save_path}/epoch_{epoch}_loss_{current_loss:.4f}_acc_{avg_accuracy:.4f}')

  0%|          | 0/50 [00:00<?, ?it/s]

2024-01-30 05:12:00 | INFO | Epoch 0 | Train Loss: 1.6625 | Train Accuracy: 0.6376
2024-01-30 05:12:01 | INFO | Validation Loss: 1.1291 | Validation Accuracy: 0.7830

2024-01-30 05:12:14 | INFO | Epoch 1 | Train Loss: 0.9077 | Train Accuracy: 0.8081
2024-01-30 05:12:15 | INFO | Validation Loss: 0.7537 | Validation Accuracy: 0.8370

2024-01-30 05:12:30 | INFO | Epoch 2 | Train Loss: 0.6695 | Train Accuracy: 0.8627
2024-01-30 05:12:31 | INFO | Validation Loss: 0.5639 | Validation Accuracy: 0.8961

2024-01-30 05:12:48 | INFO | Epoch 3 | Train Loss: 0.5125 | Train Accuracy: 0.9013
2024-01-30 05:12:50 | INFO | Validation Loss: 0.5629 | Validation Accuracy: 0.8841

2024-01-30 05:13:06 | INFO | Epoch 4 | Train Loss: 0.4445 | Train Accuracy: 0.9102
2024-01-30 05:13:08 | INFO | Validation Loss: 0.4546 | Validation Accuracy: 0.8974

2024-01-30 05:13:24 | INFO | Epoch 5 | Train Loss: 0.3912 | Train Accuracy: 0.9182
2024-01-30 05:13:26 | INFO | Validation Loss: 0.4606 | Validation Accuracy: 0.8976

In [22]:
# load the pretrained model
model.load_state_dict(torch.load(best_model_path))

model.to(device)
model.eval()

total = 0
correct = 0
true_labels = []
predicted_labels = []
with torch.no_grad():
    for data in test_dataloader:
        loss, accuracy, predicted = model_fn(data, model, criterion, device, which_type='test')
        labels = data.edata['label'].to(device)
        
        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())
                
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

logging.info(f'Test Accuracy: {100 * correct / total:.4f} %\n\n\n')

2024-01-30 05:25:29 | INFO | Test Accuracy: 99.0319 %





In [23]:
report_data = classification_report(true_labels, predicted_labels, output_dict=True)
report_df = pd.DataFrame(report_data).transpose()

output_path = "./result/GraphSAGE_emb256"
if not os.path.isdir(output_path):
    os.makedirs(output_path)
    
report_df.reset_index(inplace=True, names='label')
label_list = []
for idx, row in report_df.iterrows():
    if row["label"].isdigit():
        row["label"] = index2label[int(row["label"])]
    label_list.append(row["label"])
report_df["label"] = label_list
report_df.to_csv(f'{output_path}/result.csv', index=False)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [24]:
report_df

Unnamed: 0,label,precision,recall,f1-score,support
0,T1105_1095434782a00c8a4772a11e625bcf5d_B,1.000000,1.000000,1.000000,17.000000
1,T1105_e6715e61f5df646692c624b3499384c4_I,0.937500,0.750000,0.833333,1140.000000
2,T1003.003_9f73269695e54311dd61dc68940fb3e1_B,1.000000,1.000000,1.000000,15.000000
3,T1074.001_6469befa-748a-4b9c-a96d-f191fde47d89_I,1.000000,0.900000,0.947368,30.000000
4,T1112_ba6f6214dbd17c54001e0a163b60f151_I,1.000000,1.000000,1.000000,34.000000
...,...,...,...,...,...
272,T1120_7b9c7afaefa59aab759b49af0d699ac1_I,0.539216,0.509259,0.523810,108.000000
273,T1112_fd992e8ecfdac9b56dd6868904044827_I,1.000000,1.000000,1.000000,44.000000
274,accuracy,0.990319,0.990319,0.990319,0.990319
275,macro avg,0.843788,0.819779,0.820915,814055.000000
