In [1]:
! pwd

/media/Raid6_disk/bai/bai/final_synthesized/code


In [2]:
import os
import json
import copy
import pickle
import random
import logging
import linecache
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

from sklearn.decomposition import PCA
from sklearn.metrics import classification_report

import dgl
import dgl.nn as dglnn
from dgl.nn import GraphConv, GATConv, SAGEConv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup
# from thundersvm import OneClassSVM

import warnings
warnings.filterwarnings('ignore')

os.environ['CUDA_VISIBLE_DEVICES'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [4]:
import re

def build_dictionary(file_path):
    with open(file_path, 'r') as file:
        next(file)
        # 使用正则表达式去除行末的数字
        dictionary = {re.sub(r'\s\d+$', '', line.strip()): index for index, line in enumerate(file)}
    return dictionary
    
file_path = '../data/3_openKE-old/label2id.txt'  # 替換為您檔案的路徑
label2index = build_dictionary(file_path)
index2label = {v: k for k, v in label2index.items()}

index2label

{0: 'T1059.001_702bfdd2-9947-4eda-b551-c3a1ea9a59a2_B',
 1: 'T1078.001_d0ca00832890baa1d42322cf70fcab1a_B',
 2: 'T1074.001_e6dfc7e89359ac6fa6de84b0e1d5762e_B',
 3: 'T1491_68235976-2404-42a8-9105-68230cfef562_B',
 4: 'T1016_14a21534-350f-4d83-9dd7-3c56b93a0c17_B',
 5: 'T1491_47d08617-5ce1-424a-8cc5-c9c978ce6bf9_I',
 6: 'T1074.001_4e97e699-93d7-4040-b5a3-2e906a58199e_I',
 7: 'T1040_6881a4589710d53f0c146e91db513f01_B',
 8: 'T1547.009_b6e5c895c6709fe289352ee23f062229_B',
 9: 'T1564.001_66a5fd5f244819181f074dd082a28905_B',
 10: 'T1047_f4b0b4129560ea66f9751275e82f6bab_B',
 11: 'T1112_257313a3c93e3bb7dfb60d6753b09e34_I',
 12: 'T1047_ac2764f7a67a9ce92b54e8e59b361838_B',
 13: 'T1518.001_33a24ff44719e6ac0614b58f8c9a7c72_B',
 14: 'T1204.002_522f3f35cd013e63830fa555495a0081_I',
 15: 'T1059.001_ccdb8caf-c69e-424b-b930-551969450c57_B',
 16: 'T1105_0856c235a1d26113d4f2d92e39c9a9f8_B',
 17: 'T1547_fe9eeee9a7b339089e5fa634b08522c1_I',
 18: 'T1574.001_63bbedafba2f541552ac3579e9e3737b_B',
 19: 'T1137.002

In [5]:
with open("../data/3_openKE/relation2id.txt") as fp:
    next(fp)
    
    relations = [' '.join(line.strip().split(' ')[:-1]) for line in fp.readlines()]

print(relations)
print(len(relations))

['Process Create', 'Process Start', 'CreateFile', 'SetBasicInformationFile', 'SetDispositionInformationEx', 'SetDispositionInformationFile', 'WriteFile', 'TCP Connect', 'TCP Send', 'UDP Send', 'TCP Disconnect', 'RegQueryKey', 'RegQueryValue', 'CloseFile', 'QueryAllInformationFile', 'QueryAttributeTagFile', 'QueryBasicInformationFile', 'QueryDirectory', 'QueryNetworkOpenInformationFile', 'ReadFile', 'TCP Receive', 'UDP Receive', 'RegCreateKey', 'RegSetValue', 'RegCloseKey', 'RegDeleteValue', 'RegOpenKey']
27


# Make Dataset

In [6]:
def get_value(event):
    global type2attr

    srcUUID = event["srcNode"]["UUID"]
    srcType = event["srcNode"]["Type"]
    srcAttr = event["srcNode"][type2attr[srcType]]
    dstUUID = event["dstNode"]["UUID"] if event["dstNode"] != None else srcUUID
    dstType = event["dstNode"]["Type"] if event["dstNode"] != None else srcType
    dstAttr = event["dstNode"][type2attr[dstType]] if event["dstNode"] != None else srcAttr
    return srcUUID, srcAttr, dstUUID, dstAttr, event["relation"], event["label"]

In [7]:
def make_dataset(dataset):
    global node_ent2idx, edge_ent2idx, node_ent2emb, edge_ent2emb

    data_list = []
    for p in tqdm(dataset, total=len(dataset), desc="Making dataset..."):
        events = dataset[p]

        nodes = set()
        edges = []
        relations = []
        labels = []
        uuid2res = {}
        for e in events:
            srcUUID, srcAttr, dstUUID, dstAttr, rel, label = get_value(e)

            uuid2res[srcUUID], uuid2res[dstUUID] = srcAttr, dstAttr
            nodes.add(srcUUID)
            nodes.add(dstUUID)
            edges.append((srcUUID, dstUUID))
            relations.append(edge_ent2idx[rel])
            labels.append(label2index[label])
        nodes = list(nodes)  
        node_feat = [torch.tensor(node_ent2emb[node_ent2idx[uuid2res[uuid]]], dtype=torch.float32) for uuid in nodes]
        edge_attr = [torch.tensor(edge_ent2emb[idx], dtype=torch.float32) for idx in relations]

        src = [nodes.index(src_uuid) for src_uuid, dst_uuid in edges]
        dst = [nodes.index(dst_uuid) for src_uuid, dst_uuid in edges]
        edge_index = torch.tensor([src, dst], dtype=torch.long)

        
        data_list.append({
            "labels": labels,
            "num_nodes": len(nodes),
            "node_feat": node_feat,
            "edge_attr": edge_attr,
            "edge_index": edge_index,
            "proc": p
        })
    return data_list         

In [9]:
# with open("../data/3_openKE/relation2id.txt") as fp:
#     next(fp)
#     edge_ent2idx = {re.sub(r'\s\d+$', '', line.strip()): index for index, line in enumerate(fp)}

# edge_ent2idx

In [10]:
# with open("../data/3_openKE/entity2id.pkl", 'rb') as fp:
#     node_ent2idx = pickle.load(fp)
# node_ent2idx

In [18]:
node_ent2idx['HKCR\\DirectShow\\MediaObjects\\bbeea841-0a63-4f52-a7ab-a9b3a84ed38a\\OutputTypes']

1380581

In [11]:
# DIM = 256
# embedding = 'transE'

# embedding = f'{embedding}_{DIM}'
# with open(f"../data/4_embedding/{embedding}.vec.json", "r") as f:
#     tmp = json.load(f)

# node_ent2emb = {idx:emb for idx, emb in enumerate(tmp["ent_embeddings.weight"])}
# edge_ent2emb = {idx:emb for idx, emb in enumerate(tmp["rel_embeddings.weight"])}

In [12]:
# len(node_ent2emb)

In [13]:
# len(edge_ent2emb)

In [14]:
type2attr = {
    "Process": "Cmdline", 
    "File": "Name", 
    "Registry": "Key", 
    "Network": "Dstaddress"
}

In [15]:
# if "" not in node_ent2idx:
#     node_ent2idx[""] = node_ent2idx["nan"]
#     # node_ent2idx[""] = 0

In [22]:
with open(f'../data/For_YR_pickle/entity2id.pkl', 'rb') as fp:
    node_ent2idx = pickle.load(fp)
with open(f'../data/For_YR_pickle/edge_ent2id.pkl', 'rb') as fp:
    edge_ent2idx = pickle.load(fp)
# with open(f'../data/For_YR_pickle/node_emb2id.pkl', 'rb') as fp:
#     node_ent2emb = pickle.load(fp)
# with open(f'../data/For_YR_pickle/edge_emb2id.pkl', 'rb') as fp:
#     edge_ent2emb = pickle.load(fp)
# node_ent2idx[""] = node_ent2idx["nan"]

In [28]:
with open("../data/For_YR_pickle/edge_emb2id.pkl", 'rb') as fp:
    edge_emb2id = pickle.load(fp)

In [29]:
len(edge_emb2id[0])

16

In [30]:
edge_ent2emb = {idx:emb for idx, emb in enumerate(edge_emb2id)}

In [32]:
with open("../data/For_YR_pickle/node_emb2id.pkl", 'rb') as fp:
    node_emb2id = pickle.load(fp)

In [33]:
node_ent2emb = {idx:emb for idx, emb in enumerate(node_emb2id)}

In [34]:
len(node_ent2emb[0])

256

In [35]:
len(node_ent2emb)

1497309

# Make Torch dataset

In [36]:
class GraphDataset(Dataset):
    def __init__(self, data_list, device):
        self.data_list = data_list
        self.device = device

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data = self.data_list[idx]
        return data

def collate(samples):
    data_list = samples
    batched_graphs = []
    for data in data_list:
        g = dgl.graph((data["edge_index"][0], data["edge_index"][1]), num_nodes=data["num_nodes"])

        g.ndata['feat'] = torch.stack(data["node_feat"])
        g.edata['feat'] = torch.stack(data["edge_attr"])
        g.edata['label'] = torch.tensor(data["labels"])  

        batched_graphs.append(g)
    
    return dgl.batch(batched_graphs)

# Model

In [37]:
class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(GraphSAGE, self).__init__()
        self.layer1 = dglnn.SAGEConv(in_dim, hidden_dim, 'pool')
        self.layer2 = dglnn.SAGEConv(hidden_dim, out_dim, 'pool')
        self.dropout = nn.Dropout(0.25)

    def forward(self, g, inputs):
        h = self.layer1(g, inputs)
        h = torch.relu(h)
        h = self.dropout(h)
        h = self.layer2(g, h)
        return h
    
class MLPPredictor(nn.Module):
    def __init__(self, out_feats, out_classes, edge_embedding_dim):
        super().__init__()
        self.W = nn.Linear(out_feats*2 + edge_embedding_dim, out_classes)

    def apply_edges(self, edges, edge_feat):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        h_e = edge_feat
        score = self.W(torch.cat([h_u, h_v, h_e], 1))
        return {'score': score}

    def forward(self, graph, h, edge_feat):
        with graph.local_scope():
            graph.ndata['h'] = h
            # graph.apply_edges(self.apply_edges)
            graph.apply_edges(lambda edges: self.apply_edges(edges, edge_feat))
            return graph.edata['score']
        
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_classes, edge_embedding_dim):
        super().__init__()
        self.sage = GraphSAGE(in_features, hidden_features, out_features)
        self.pred = MLPPredictor(out_features, num_classes, edge_embedding_dim)
      
    def forward(self, g, node_feat, edge_feat, return_logits=False):
        h = self.sage(g, node_feat)
        logits = self.pred(g, h, edge_feat)
        
        return logits

In [38]:
def same_seeds(seed = 42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def model_fn(batched_g, model, criterion, device, which_type='train'):
    """Forward a batch through the model."""
    batched_g = batched_g.to(device)
    
    labels = batched_g.edata['label'].to(device)    
    # logits = model(batched_g, batched_g.ndata['feat'].float())
    logits = model(batched_g, batched_g.ndata['feat'].float(), batched_g.edata['feat'].float())
    loss = criterion(logits, labels)

    output = torch.softmax(logits, dim=1)
    preds = output.argmax(1)
    
    accuracy = torch.mean((preds == labels).float())
        
    return loss, accuracy, preds

# Main

In [39]:
same_seeds(42)

model = Model(in_features=256, hidden_features=512, out_features=1024, num_classes=len(label2index), edge_embedding_dim = 16).to(device)
# model.load_state_dict(torch.load("./model/GraphSAGE_emb256_0201/epoch_47_loss_0.0537_acc_0.9825"))
model.load_state_dict(torch.load("./graphSAGE_exp3/checkpoint_graphSAGE/best_model_GraphSAGE_transH_256-edge-small_batchsize-bigdim.pt"))

criterion = nn.CrossEntropyLoss()

RuntimeError: Error(s) in loading state_dict for Model:
	size mismatch for pred.W.weight: copying a param with shape torch.Size([278, 2304]) from checkpoint, the shape in current model is torch.Size([278, 2064]).

In [107]:
def labeling(events):
    tmp = set()
    labeled_events = []
    for e in events:
        relation = e["relation"]
        label = e["label"]

        if relation not in relations:
            continue
        if "Linked" in label:
            label = "T1204.002_522f3f35cd013e63830fa555495a0081"

        if e["label"] != "benign":
            if e["label"] not in tmp:
                tmp.add(e["label"])
                label = f"{label}_B"
            else:
                label = f"{label}_I"
        else:
            label = "O"

        e["label"] = label
        labeled_events.append(e)
    return labeled_events

def re_labeling(all_label, all_predict):
    global index2label
    y_true = copy.deepcopy(all_label)
    y_pred = copy.deepcopy(all_predict)

    y_true_re = []
    y_pred_re = []
    for i in range(len(y_true)):
        t = index2label[y_true[i]]
        p = index2label[y_pred[i]]
        t_ = t[:-2] if t != "O" else t
        p_ = p[:-2] if p != "O" else p
        y_true_re.append(t_)
        y_pred_re.append(p_)

    return y_true_re, y_pred_re

def numeric(file_path):
    idx = int(file_path.split('_')[-1].split('.')[0])
    return idx

In [108]:
def get_event_emb(events):
    event_embs = []
    for e in events:
        if e["relation"] not in relations:
            continue

        srcUUID, srcAttr, dstUUID, dstAttr, rel, label = get_value(e)
        emb_src = torch.tensor(node_ent2emb[node_ent2idx[srcAttr]], dtype=torch.float32)
        emb_dst = torch.tensor(node_ent2emb[node_ent2idx[dstAttr]], dtype=torch.float32)
        emb_rel = torch.tensor(edge_ent2emb[edge_ent2idx[rel]], dtype=torch.float32)

        event_emb = torch.cat((emb_src, emb_rel, emb_dst), dim=0)
        event_embs.append(event_emb)
    event_embs = torch.stack(event_embs)
    return event_embs

In [109]:
def anomaly_detection(malicious_uuid, proc2events, proc2attack, save_path):
    labels = []
    predicts = []
    for p in proc2events:
        label = 0 if p not in malicious_uuid else 1
        predict = 0 if p not in proc2attack else 1
        labels.append(label)
        predicts.append(predict)

    report_data = classification_report(labels, predicts, labels=[0,1], target_names=["benign", "attack"], output_dict=True)
    report_df = pd.DataFrame(report_data).transpose()
    report_df.reset_index(inplace=True, names='label')
    report_df.to_csv(f'{save_path}/anomaly_detection_result.csv', index=False)

In [110]:
with open('../data/campaign2suspicious_proc.json', 'r') as fp:
    campaign2suspicious_proc = json.load(fp)

In [111]:
for campaign in os.listdir("../data/SynthesizedCampaign_0128/"):
    save_path = f"./result/ToyDataset2/Synthesized_campaign_GraphSAGE_0201/{campaign}"
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    logging.info(f'Start {campaign} testing...')
    with open(f"../data/SynthesizedCampaign_0128/{campaign}/sequence_0/campaign_0/expand_malicious_process/expand_malicious_process.json") as fp:
        malicious_events = json.load(fp)
    malicious_uuid = [e["srcNode"]["UUID"] for e in malicious_events]
    malicious_uuid = list(set(malicious_uuid))
    
    paths = sorted(glob(f"../data/SynthesizedCampaign_0128/{campaign}/sequence_0/campaign_0/synthesized_event/synthesized_events_*.json"), key=numeric)
    campaign_events = []
    for p in paths:
        campaign_events.extend(linecache.getlines(p))
    campaign_events = [json.loads(e) for e in campaign_events]
    campaign_events = labeling(campaign_events)
    
    proc2events = {}
    for e in campaign_events:
        proc = e["srcNode"]["UUID"]
        proc2events.setdefault(proc, [])
        proc2events[proc].append(e)

    # Step1: detecting malicious process
    # attack_detection_model = pickle.load(open("./model/nu_0.06_p_0.784_r_0.880_f_0.829.pkl", 'rb'))
    # proc2attack = {}
    # for proc, events in tqdm(proc2events.items(), total=len(proc2events), desc="Detecting attack event"):    
    #     event_embs = get_event_emb(events)
    #     predict = attack_detection_model.predict(event_embs)
    #     attack_cnt = predict[predict == -1].size
    #     if attack_cnt != 0:
    #         proc2attack[proc] = events
    # anomaly_detection(malicious_uuid, proc2events, proc2attack, save_path)
    # logging.info(f"Num of raw process: {len(proc2events)}, Num of attack process: {len(proc2attack)}")
    proc2attack = {p:events for p, events in proc2events.items() if p in campaign2suspicious_proc[campaign]}
        
    # Step2: detecting TTP per each process
    test_data = make_dataset(proc2events)
    test_GraphDataset = GraphDataset(test_data, device)
    test_dataloader = DataLoader(test_GraphDataset, batch_size=32, shuffle=False, collate_fn=collate)
    
    save_torch_data_path = f"./torch_data/ToyDataset2/Synthesized_campaign_GraphSAGE_0201/{campaign}"
    if not os.path.isdir(save_torch_data_path):
        os.makedirs(save_torch_data_path)
    with open(f"{save_torch_data_path}/test.pkl", 'wb') as fp:
        pickle.dump(test_dataloader,fp)

    model.eval()
    total = 0
    correct = 0
    true_labels = []
    predicted_labels = []
    with torch.no_grad():
        for data in test_dataloader:
            loss, accuracy, predicted = model_fn(data, model, criterion, device, which_type='test')
            labels = data.edata['label'].to(device)

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
                    
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # result
    labels = sorted(list(set(true_labels)))
    if 165 in labels:
        labels.remove(165)
    targets = [index2label[l] for l in labels]
    report_data = classification_report(true_labels, predicted_labels, labels=labels, target_names=targets, output_dict=True)
    report_df = pd.DataFrame(report_data).transpose()
    report_df.reset_index(inplace=True, names='label')
    report_df.to_csv(f'{save_path}/result.csv', index=False)
    
    # result(no BIO)
    y_true_re, y_pred_re = re_labeling(true_labels, predicted_labels)
    labels = sorted(list(set(y_true_re)))
    if "O" in labels:
        labels.remove("O")
    report_data = classification_report(y_true_re, y_pred_re, labels=labels, output_dict=True, zero_division=0.0)
    report_df = pd.DataFrame(report_data).transpose()
    report_df.reset_index(inplace=True, names='label')
    report_df.to_csv(f'{save_path}/result_(no BIO).csv', index=False)


    logging.info(f'Test Accuracy: {100 * correct / total:.4f} %')
    logging.info(f'Finish testing...\n')

2024-02-06 03:12:52 | INFO | Start Higaisa testing...
Making dataset...: 100%|██████████| 95/95 [00:29<00:00,  3.18it/s]
2024-02-06 03:14:48 | INFO | Test Accuracy: 0.0077 %
2024-02-06 03:14:48 | INFO | Finish testing...

2024-02-06 03:14:48 | INFO | Start admin338 testing...
Making dataset...: 100%|██████████| 118/118 [00:48<00:00,  2.44it/s]
2024-02-06 03:17:57 | INFO | Test Accuracy: 0.0146 %
2024-02-06 03:17:57 | INFO | Finish testing...

2024-02-06 03:17:57 | INFO | Start APT28 testing...
Making dataset...: 100%|██████████| 162/162 [05:29<00:00,  2.03s/it]
2024-02-06 03:26:10 | INFO | Test Accuracy: 0.0213 %
2024-02-06 03:26:10 | INFO | Finish testing...

2024-02-06 03:26:10 | INFO | Start FIN7 testing...
Making dataset...: 100%|██████████| 234/234 [38:11<00:00,  9.79s/it]   
2024-02-06 04:09:15 | INFO | Test Accuracy: 0.0196 %
2024-02-06 04:09:15 | INFO | Finish testing...

2024-02-06 04:09:15 | INFO | Start CobaltGroup testing...
Making dataset...: 100%|██████████| 117/117 [00:3