In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
import warnings
import os
from torch_geometric.loader import NeighborLoader
import random

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Federated-learning configuration.
num_of_ctg = 10        # GNN models per client: num_of_ctg - 1 category models plus one trained on the full graph
learning_rounds = 10   # number of train/aggregate rounds
epochs = 20            # local training epochs per model per round

# Choose ten *distinct* client hosts. Drawing with randint could select the
# same host twice (it would then be trained and averaged twice), and an
# unseeded draw made the run irreproducible. A private RNG is used so the
# global `random` state is left untouched.
_host_rng = random.Random(0)
hids = _host_rng.sample(range(100, 901), 10)
hosts = [f'SysClient0{x}.systemia.com.txt' for x in hids]

In [None]:
from pprint import pprint
import json
import copy

import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm

In [None]:
import gensim
from collections import Counter
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

def infer(doc,word2vec):
    """Embed a document as the mean of its known word vectors.

    Args:
        doc: iterable of tokens.
        word2vec: trained model exposing ``.wv`` (membership test, item
            lookup and ``vector_size``).

    Returns:
        1-D numpy array of length ``word2vec.wv.vector_size``. When no token
        of ``doc`` is in the vocabulary a zero vector of that same length is
        returned, so every node feature has a consistent dimension.
    """
    keyed_vectors = word2vec.wv
    word_emb = [keyed_vectors[word] for word in doc if word in keyed_vectors]

    if len(word_emb) == 0:
        # Size the fallback from the model instead of a hard-coded 20 so the
        # feature matrix never becomes ragged.
        return np.zeros(keyed_vectors.vector_size)

    # Average directly in numpy (float32, as before) - no torch round-trip.
    return np.mean(np.asarray(word_emb, dtype=np.float32), axis=0)

In [None]:
def preprocess(data):
    """Filter and de-duplicate raw events.

    An event is kept when its object type is PROCESS/FILE/FLOW/MODULE, its
    action is neither START nor TERMINATE, and actor and object differ.
    Events sharing (action, actorID, objectID, object, pid, ppid) collapse
    to one entry: the last occurrence wins, at the position of the first.
    """
    allowed_objects = ('PROCESS', 'FILE', 'FLOW', 'MODULE')
    skipped_actions = ('START', 'TERMINATE')

    unique_events = {}
    for event in data:
        is_allowed_object = event['object'] in allowed_objects
        is_allowed_action = event['action'] not in skipped_actions
        is_not_self_loop = event['actorID'] != event['objectID']
        signature = (
            event['action'],
            event['actorID'],
            event['objectID'],
            event['object'],
            event['pid'],
            event['ppid'],
        )
        if is_allowed_object and is_allowed_action and is_not_self_loop:
            unique_events[signature] = event
    return [*unique_events.values()]

In [None]:
def Extract_Semantic_Info(event):
    """Attach ``actorname`` / ``objectname`` to an event, in place.

    The names are taken from ``event['properties']`` using a per-object-type
    list of keys: the first key gives the actor name, the remaining ones are
    joined with a space to form the object name.

    Returns:
        The same (mutated) event, or ``None`` when the object type is not
        PROCESS/FILE/MODULE/FLOW or any required property is missing/falsy.
    """
    label_mapping = {
        "PROCESS": ('parent_image_path', 'image_path'),
        "FILE": ('image_path', 'file_path'),
        "MODULE": ('image_path', 'module_path'),
        "FLOW": ('image_path', 'dest_ip', 'dest_port')
    }

    wanted_keys = label_mapping.get(event['object'])
    properties = event['properties']
    if not wanted_keys:
        return None

    labels = [properties.get(name) for name in wanted_keys]
    if not all(labels):
        return None

    # Build the joined name first so a failure leaves the event untouched.
    object_name = ' '.join(labels[1:])
    event["actorname"] = labels[0]
    event["objectname"] = object_name
    return event

In [None]:
def describe(x):
    """Turn an event into a list of tokens describing it.

    The sentence layout depends on ``x['object']`` (PROCESS, FILE, FLOW; any
    other type is formatted like MODULE) and is split on single spaces.

    Returns:
        List of tokens. When a required property is missing the result is
        ``['']`` (the split of an empty string), never an empty list.
    """
    action = x["action"]
    props = x['properties']
    typ = x['object']

    try:
        if typ == 'PROCESS':
            phrase = f"{props['parent_image_path']} {action} {props['image_path']} {props['command_line']}"
        elif typ == 'FILE':
            phrase = f"{props['image_path']} {action} {props['file_path']}"
        elif typ == 'FLOW':
            # The double space is kept from the original format; it yields an
            # empty-string token after the split.
            phrase = f"{props['image_path']} {action}  {props['dest_ip']} {props['dest_port']} {props['direction']}"
        else:
            phrase = f"{props['image_path']} {action} {props['module_path']}"
    except (KeyError, TypeError):
        # Only missing / malformed properties are tolerated; a bare `except`
        # would also hide KeyboardInterrupt and genuine programming errors.
        phrase = ''

    return phrase.split(' ')

In [None]:
def transform(text):
    """Convert raw event dicts into a timestamp-sorted DataFrame.

    Events are enriched by ``Extract_Semantic_Info``, filtered and
    de-duplicated by ``preprocess``, and then given two extra fields:
    ``phrase`` (token list from ``describe``) and ``proc_name`` (the
    ``image_path`` property, or ``''`` when absent). The event dicts are
    modified in place.

    Args:
        text: iterable of event dicts.

    Returns:
        DataFrame with one row per kept event, sorted by ``timestamp``
        ascending. An empty DataFrame is returned when no event survives.
    """
    events = [event for event in (Extract_Semantic_Info(x) for x in text) if event]
    data = preprocess(events)

    # Attach the phrase straight to its own event. The previous version built
    # a separately filtered list and re-joined it by position, which would
    # mis-pair phrases and events if the filter ever removed an entry.
    for event in data:
        event['phrase'] = describe(event)
        event['proc_name'] = event['properties'].get('image_path', '')

    df = pd.DataFrame.from_dict(data)
    if df.empty:
        return df

    # Drop the last six characters of the timestamp before parsing
    # (presumably a UTC-offset suffix - confirm against the raw logs).
    # `infer_datetime_format` is deprecated in pandas 2.x and is not needed.
    df['timestamp'] = pd.to_datetime(df['timestamp'].str[:-6])
    return df.sort_values(by='timestamp', ascending=True)

def load_data(dataset_id):
    """Load ``hosts/<dataset_id>`` (one JSON event per line) and build its graph.

    Returns:
        The tuple produced by ``prepare_graph(transform(events))``.
    """
    # The context manager closes the file even if a line fails to parse;
    # blank lines are skipped instead of crashing json.loads.
    with open(f"hosts/{dataset_id}") as f:
        content = [json.loads(line) for line in f if line.strip()]
    return prepare_graph(transform(content))

In [None]:
def prepare_graph(df):
    """Build graph inputs from an event DataFrame.

    Every actor and object id becomes a node. A node's document is the
    concatenation of the phrases of all events it takes part in; its label
    is the object-type code (actors are always labelled PROCESS).

    Args:
        df: DataFrame with columns actorID, objectID, object, action,
            phrase (list of tokens) and proc_name.

    Returns:
        (features, feat_labels, edge_index, mapp, all_procs, idx_to_proc):
        per-node token lists, per-node integer labels, ``[sources, targets]``
        node-index lists (one entry per event), node-index -> original id,
        the list of distinct actor process names, and node-index -> process
        name for actor nodes.
    """
    type_codes = {'PROCESS':0,'FLOW':1,'FILE':2,'MODULE':3}

    node_docs = {}
    node_labels = {}
    actor_proc = {}
    raw_edges = []

    for pos in range(len(df)):
        row = df.iloc[pos]
        actor_id = row['actorID']
        object_id = row['objectID']

        node_docs.setdefault(actor_id, []).extend(row['phrase'])
        node_labels[actor_id] = type_codes['PROCESS']

        node_docs.setdefault(object_id, []).extend(row['phrase'])
        node_labels[object_id] = type_codes[row['object']]

        raw_edges.append((actor_id, object_id, row['action']))
        actor_proc[actor_id] = row['proc_name']

    # Node order is the order in which ids were first seen.
    mapp = list(node_docs)
    features = [node_docs[node_id] for node_id in mapp]
    feat_labels = [node_labels[node_id] for node_id in mapp]
    position_of = {node_id: pos for pos, node_id in enumerate(mapp)}

    proc_names = set()
    idx_to_proc = {}
    for pos, node_id in enumerate(mapp):
        if node_id in actor_proc:
            proc_names.add(actor_proc[node_id])
            idx_to_proc[pos] = actor_proc[node_id]

    edge_index = [
        [position_of[edge[0]] for edge in raw_edges],
        [position_of[edge[1]] for edge in raw_edges],
    ]

    return features,feat_labels,edge_index,mapp,list(proc_names),idx_to_proc

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv, GATConv
import torch.nn.functional as F
import torch.nn as nn

class GCN(torch.nn.Module):
    """Two SAGEConv layers (20 -> 32 -> 20) followed by a linear 4-class head.

    NOTE(review): ``forward`` returns softmax probabilities while training
    feeds them to ``CrossEntropyLoss``, which applies log-softmax itself -
    confirm this double normalisation is intended.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys used by the saved
        # global*.pth files, so they must not be renamed.
        self.conv1 = SAGEConv(20, 32, normalize=True)
        self.conv2 = SAGEConv(32, 20, normalize=True)
        self.linear = nn.Linear(in_features=20, out_features=4)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return per-node class probabilities (softmax over dim 1)."""
        scores = self.linear(self.encode(x, edge_index))
        return F.softmax(scores, dim=1)

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return the output of the second SAGEConv layer for each node."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)

    def freeze_conv_layers(self):
        """Disable gradients for both convolution layers (the head stays trainable)."""
        for layer in (self.conv1, self.conv2):
            for param in layer.parameters():
                param.requires_grad = False

In [None]:
from gensim.models.callbacks import CallbackAny2Vec
import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

class EpochSaver(CallbackAny2Vec):
    """Word2Vec callback that saves the model to ``<client_id>.model`` after each epoch.

    NOTE(review): a class with the same name but a ``filename`` argument is
    defined again further down the notebook and replaces this one once that
    cell has run.
    """

    def __init__(self,client_id):
        self.cid = client_id
        self.epoch = 0  # number of completed epochs

    def on_epoch_end(self, model):
        checkpoint_path = f"{self.cid}.model"
        model.save(checkpoint_path)
        self.epoch = self.epoch + 1
        
class EpochLogger(CallbackAny2Vec):
    """Silent Word2Vec callback that only counts completed epochs."""

    def __init__(self):
        self.epoch = 0  # number of completed epochs

    def on_epoch_begin(self, model):
        # Intentionally does nothing.
        return None

    def on_epoch_end(self, model):
        self.epoch = self.epoch + 1
        
def train_word2vec_func(docs,client_id):
    """Train a 20-dimensional Word2Vec model on ``docs`` for one client.

    Args:
        docs: iterable of token lists.
        client_id: value handed to ``EpochSaver``, which saves the model
            after every epoch.

    Returns:
        The trained ``Word2Vec`` model (previously it was discarded and the
        function returned ``None``).
    """
    logger = EpochLogger()
    saver = EpochSaver(client_id)
    word2vec = Word2Vec(sentences=docs, vector_size=20, window=5, min_count=1,workers=5,epochs=100,callbacks=[saver,logger])
    return word2vec

In [None]:
def init_gnns():
    """Create ``num_of_ctg`` freshly initialised GCN models on ``device``."""
    return [GCN().to(device) for _ in range(num_of_ctg)]

In [None]:
def define_categories(pids):
    """Split the distinct values of ``pids`` into ``num_of_ctg - 1`` sets.

    Bucket sizes differ by at most one; the first ``len % n`` buckets get
    the extra element. Membership depends on set iteration order. Raises
    ``ZeroDivisionError`` when ``num_of_ctg`` is 1.
    """
    bucket_count = num_of_ctg - 1
    distinct = list(set(pids))
    base_size, remainder = divmod(len(distinct), bucket_count)

    buckets = []
    start = 0
    for bucket_no in range(bucket_count):
        size = base_size + (1 if bucket_no < remainder else 0)
        buckets.append(set(distinct[start:start + size]))
        start += size
    return buckets

In [None]:
def map_pids_to_category_indices(pids, categories):
    """Map each pid to the index of the first category containing it.

    Pids that appear in no category are left out of the result.
    """
    pid_to_category_index = {}
    for pid in pids:
        first_hit = next(
            (position for position, members in enumerate(categories) if pid in members),
            None,
        )
        if first_hit is not None:
            pid_to_category_index[pid] = first_hit
    return pid_to_category_index

In [None]:
# Shared state filled by load_clients_data():
#   procs_total - process names collected from the loaded clients
#   data_cache  - client id -> graph data returned by load_data()
#   categories  - groups of process names (None until computed)
procs_total, data_cache, categories = [], {}, None

def generate_key():
    """Return a newly generated Fernet key."""
    fresh_key = Fernet.generate_key()
    return fresh_key

def encrypt_data(data, key):
    """Encrypt ``str(item)`` of every item in ``data`` with Fernet; returns a list of tokens."""
    cipher = Fernet(key)
    tokens = []
    for item in data:
        tokens.append(cipher.encrypt(str(item).encode()))
    return tokens

def decrypt_data(nested_data, key):
    """Decrypt a collection of token collections; returns a list of lists of strings."""
    cipher = Fernet(key)
    decrypted_groups = []
    for group in nested_data:
        decrypted_groups.append([cipher.decrypt(token).decode() for token in group])
    return decrypted_groups

def load_clients_data(client_ids):
    """Load every client's graph data and derive the process-name categories.

    Side effects (module globals):
        data_cache[client_id] is set to
            ``[docs, labels, edges, mapp, pids, idx_to_pid]``;
        procs_total is rebuilt from all cached clients;
        categories is set to a list of lists of process names.
    """
    global data_cache,categories,procs_total
    key = generate_key()

    for client_id in client_ids:
        data_cache[client_id] = list(load_data(client_id))

    # Rebuild from the cache rather than appending, so re-running the cell
    # does not keep growing the list.
    procs_total = [name for entry in data_cache.values() for name in entry[4]]

    # De-duplicate BEFORE encrypting: Fernet tokens are randomised, so two
    # encryptions of the same name differ and the set() inside
    # define_categories could never merge them - one process name would end
    # up in several categories.
    unique_procs = list(dict.fromkeys(procs_total))

    encrypted_procs_total = encrypt_data(unique_procs, key)
    categories = define_categories(encrypted_procs_total)
    categories = decrypt_data(categories, key)

In [None]:
from torch.nn import CrossEntropyLoss
import copy

# Freshly initialised models; train_gnn_func and server_aggregate deep-copy
# these as their starting point.
templates = init_gnns()

def train_gnn_func(nodes,labels,edges,mapp,pids,idx_to_pid):
    """Train one client's set of GNNs on its graph.

    Each category model ``i`` is trained on the process nodes whose process
    name falls in ``categories[i]`` plus nodes reached through two outgoing
    hops; the last model is trained on the whole graph.

    Args:
        nodes: per-node feature vectors.
        labels: per-node integer class labels.
        edges: ``[sources, targets]`` node-index lists.
        mapp: node-index -> original id (unused here, kept for the interface).
        pids: process names present on this client.
        idx_to_pid: node index (int or numeric string) -> process name.

    Returns:
        List of ``len(templates)`` models; an entry is ``None`` when the
        client has no node for that category.
    """
    global categories ,epochs

    def get_neighbors(edge_index, nodes):
        # Distinct targets of edges whose source is one of `nodes`.
        neighbors = []
        for node in nodes:
            hits = edge_index[0] == node
            neighbors.extend(edge_index[1, hits].tolist())
        return torch.tensor(list(set(neighbors)), dtype=torch.long)

    pid_to_gnn_index = map_pids_to_category_indices(pids, categories)

    train_splits = [[] for _ in range(len(categories))]
    # Iterate over items so the lookup works whether the keys are ints (as
    # produced by prepare_graph) or strings; the former
    # `idx_to_pid[str(i)]` raised KeyError for int keys.
    for node_idx, pname in idx_to_pid.items():
        train_splits[pid_to_gnn_index[pname]].append(int(node_idx))

    local_models = [copy.deepcopy(x) for x in templates]

    # The graph is identical for every model, so build it once.
    graph = Data(
        x=torch.tensor(np.asarray(nodes),dtype=torch.float).to(device),
        y=torch.tensor(labels,dtype=torch.long).to(device),
        edge_index=torch.tensor(edges,dtype=torch.long).to(device),
    )
    criterion = CrossEntropyLoss()

    for i in range(len(local_models)-1):

        if len(train_splits[i]) == 0:
            local_models[i] = None
            continue

        model = local_models[i]
        checkpoint = f"global{i}.pth"
        if os.path.exists(checkpoint):
            # map_location lets a checkpoint saved on GPU load on a CPU host.
            model.load_state_dict(torch.load(checkpoint, map_location=device))

        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        mask[train_splits[i]] = True

        # NOTE(review): only nodes found at the second hop are added to the
        # mask; one-hop neighbours are used as a stepping stone only.
        one_hop_neighbors = get_neighbors(graph.edge_index, train_splits[i])
        two_hop_neighbors = get_neighbors(graph.edge_index, one_hop_neighbors)
        two_hop_neighbors = two_hop_neighbors[~mask[two_hop_neighbors]]
        mask[two_hop_neighbors] = True

        # The loader does not change between epochs, so create it once.
        loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
        seed_count = mask.sum().item()

        for epoch in range(epochs):
            total_loss = 0
            for subg in loader:
                model.train()
                optimizer.zero_grad()
                out = model(subg.x, subg.edge_index)
                loss = criterion(out, subg.y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * subg.batch_size
            print("Loss: ", total_loss / seed_count, '\n')

    # Last model: full-batch training on the whole graph.
    # NOTE(review): unlike the category models it is never initialised from
    # a saved global checkpoint here - confirm that is intended.
    full_model = local_models[-1]
    optimizer = torch.optim.Adam(full_model.parameters(), lr=0.01, weight_decay=5e-4)

    for epoch in range(epochs):
        full_model.train()
        optimizer.zero_grad()
        out = full_model(graph.x, graph.edge_index)
        loss = criterion(out, graph.y)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Loss: {loss.item()}")

    return local_models

In [None]:
from gensim.models.callbacks import CallbackAny2Vec
import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

In [None]:
class EpochSaver(CallbackAny2Vec):
    """Word2Vec callback that saves the model to ``filename`` after each epoch."""

    def __init__(self,filename):
        self.filename = filename
        self.epoch = 0  # number of completed epochs

    def on_epoch_end(self, model):
        target = self.filename
        model.save(target)
        self.epoch = self.epoch + 1

In [None]:
class EpochLogger(CallbackAny2Vec):
    """Word2Vec callback that prints a line at the start and end of each epoch."""

    def __init__(self):
        self.epoch = 0  # index of the epoch currently running

    def on_epoch_begin(self, model):
        message = "Epoch #{} start".format(self.epoch)
        print(message)

    def on_epoch_end(self, model):
        message = "Epoch #{} end".format(self.epoch)
        print(message)
        self.epoch = self.epoch + 1

In [None]:
def train_client_word2vec_model(ids,data,vector_size=20):
    """Train and save one Word2Vec model per client.

    Args:
        ids: client ids; each model is saved to ``<id>.model`` after every
            epoch by ``EpochSaver``.
        data: mapping client id -> graph data whose first element is the
            list of token lists to train on.
        vector_size: embedding dimension. Defaults to 20 because the GCN's
            first layer takes 20 input features; the previous hard-coded 30
            produced embeddings the GNN cannot consume.

    Returns:
        Dict client id -> trained ``Word2Vec`` model (previously the models
        were discarded and ``None`` was returned).
    """
    trained_models = {}
    for h in ids:
        logger = EpochLogger()
        saver = EpochSaver(f"{h}.model")

        phrases = data[h][0]
        trained_models[h] = Word2Vec(sentences=phrases, vector_size=vector_size, window=5, min_count=1, workers=5,epochs=100,callbacks=[saver,logger])
    return trained_models

In [None]:
from cryptography.fernet import Fernet
import base64

def generate_key():
    """Return a newly generated Fernet key."""
    fresh_key = Fernet.generate_key()
    return fresh_key

def encrypt_word2vec_model(word2vec_model, encryption_key):
    """Return a new Word2Vec model whose vocabulary keys are Fernet-encrypted.

    Every word of the input model is replaced by its encrypted token (as
    text); the vectors are copied unchanged.

    NOTE(review): Fernet tokens are randomised, so the same word encrypted
    for two clients yields different keys - confirm the aggregation step
    does not rely on matching encrypted words.
    """
    cipher = Fernet(encryption_key)
    encrypted_model = gensim.models.Word2Vec(vector_size=word2vec_model.vector_size, min_count=1)

    source_vectors = word2vec_model.wv
    for plain_word in source_vectors.index_to_key:
        token = cipher.encrypt(plain_word.encode()).decode()
        encrypted_model.wv[token] = source_vectors.get_vector(plain_word)
    return encrypted_model

def decrypt_word2vec_model(word2vec_model, encryption_key):
    """Return a new Word2Vec model whose vocabulary keys are decrypted.

    Each key of the input model is treated as a Fernet token and replaced by
    its plaintext; the vectors are copied unchanged.
    """
    cipher = Fernet(encryption_key)
    decrypted_model = gensim.models.Word2Vec(vector_size=word2vec_model.vector_size, min_count=1)

    source_vectors = word2vec_model.wv
    for token in source_vectors.index_to_key:
        plain_word = cipher.decrypt(token.encode()).decode()
        decrypted_model.wv[plain_word] = source_vectors.get_vector(token)

    return decrypted_model

In [None]:
# Encrypt each client's Word2Vec vocabulary, merge the encrypted models, then
# decrypt the merged model and save it as the shared embedding.
key = generate_key()

encrypted_word_models = []
for m in hosts:
    word2vec = Word2Vec.load(f"{m}.model")
    encrypted_model = encrypt_word2vec_model(word2vec,key)
    # Was `word_models.append(...)`, an undefined name (NameError).
    encrypted_word_models.append(encrypted_model)

# NOTE(review): combine_word2vec_models is not defined in the visible part of
# this notebook - make sure it is defined before this cell runs.
encrypted_global_word = combine_word2vec_models(encrypted_word_models)

# Was `encrypted_word2vec_model`, an undefined name; the merged model is the
# one that has to be decrypted.
global_word = decrypt_word2vec_model(encrypted_global_word,key)
global_word.save("unified_word2vec.model")

In [None]:
def client_handling_loop(client_id):
    """Embed one cached client's node documents and train its local GNNs.

    Reads the client's entry from ``data_cache``, embeds every document with
    the saved unified Word2Vec model, and returns the models produced by
    ``train_gnn_func``.
    """
    docs,labels,edges,mapp,pids,idx_to_pid = data_cache[client_id]

    word2vec = Word2Vec.load("unified_word2vec.model")
    nodes_feat = [infer(doc, word2vec) for doc in docs]

    return train_gnn_func(nodes_feat,labels,edges,mapp,pids,idx_to_pid)

In [None]:
def perform_federated_learning(n_clients):
    """Run local training for every client id in ``n_clients``.

    Returns:
        One list of local models per client, in the order given.
    """
    return [client_handling_loop(client_id) for client_id in n_clients]

In [None]:
def server_aggregate(all_models):
    """Average the clients' models slot by slot and save the results.

    Args:
        all_models: ``all_models[l]`` holds every client's model for slot
            ``l``; ``None`` entries (clients without data for that slot)
            are ignored.

    Returns:
        List of global models. A slot with no contribution keeps the
        template weights and is not written to disk; the others are the
        unweighted parameter mean, saved as ``global<l>.pth``.
    """
    global_models = copy.deepcopy(templates)

    for slot, candidates in enumerate(all_models):
        contributed = [model for model in candidates if model is not None]
        if len(contributed) == 0:
            continue

        target = global_models[slot]
        merged_state = target.state_dict()
        client_states = [model.state_dict() for model in contributed]

        for name in merged_state.keys():
            stacked = torch.stack([state[name] for state in client_states], 0)
            merged_state[name] = stacked.mean(0)

        target.load_state_dict(merged_state)
        torch.save(target.state_dict(), f"global{slot}.pth")

    return global_models

In [None]:
from itertools import compress

In [None]:
def helper(MP,acts,objs,GP,edges,mapp):
    """Compute neighbourhood-adjusted detection counts.

    Args:
        MP: set of ids flagged by the model.
        acts, objs: sets of actor and object ids.
        GP: set of ground-truth positive ids.
        edges: ``[sources, targets]`` node-index lists.
        mapp: node-index -> id.

    Adjustments: a false positive within two hops of a ground-truth node is
    dropped, and a false negative within two hops of a true positive is
    counted as a true positive.

    Returns:
        (TP, FP_count, FN, TN, FP_ids) - four counts and the set of
        remaining false-positive ids.
    """
    all_pids = acts.union(objs)
    GN = all_pids - GP
    MN = all_pids - MP

    TP = MP.intersection(GP)
    FP = MP.intersection(GN)
    FN = MN.intersection(GP)

    two_hop_gp = construct_neighborhood(GP,mapp,edges,2)
    two_hop_tp = construct_neighborhood(TP,mapp,edges,2)
    FP = FP - two_hop_gp
    TP = TP.union(FN.intersection(two_hop_tp))
    FN = FN - two_hop_tp

    TP,FPC,FN = len(TP),len(FP),len(FN)

    # NOTE(review): ids present in both `acts` and `objs` are counted twice
    # here - confirm whether len(all_pids) was intended.
    TN = (len(acts) + len(objs)) - TP - FPC - FN

    # The FPR/TPR values computed here were never used and raised
    # ZeroDivisionError when a denominator was zero; they are removed.

    return TP,FPC,FN,TN,FP

In [None]:
from itertools import compress
from torch_geometric import utils

def construct_neighborhood(ids,mapp,edges,hops):
    """Collect the ids reachable from ``ids`` within ``hops`` undirected steps.

    Args:
        ids: container of starting ids.
        mapp: node-index -> id.
        edges: ``[sources, targets]`` node-index lists.
        hops: number of expansion steps; 0 yields an empty set.

    Returns:
        Set of ids found at steps 1..hops. Starting ids are included only
        if an edge leads back to them.
    """
    if hops == 0:
        return set()

    reached = set()
    for src_idx, dst_idx in zip(edges[0], edges[1]):
        src_id, dst_id = mapp[src_idx], mapp[dst_idx]
        # Edges are followed in both directions.
        if src_id in ids:
            reached.add(dst_id)
        if dst_id in ids:
            reached.add(src_id)

    return reached | construct_neighborhood(reached,mapp,edges,hops-1)

In [None]:
# Load every selected host's graph into data_cache and compute the
# process-name categories.
load_clients_data(hosts)

In [None]:
# One round = local training on every host followed by server-side averaging.
for r in range(learning_rounds):
    client_models = perform_federated_learning(hosts)
    # Transpose: from one list of models per client to one list of clients'
    # models per model slot.
    arranged_models = list(map(list, zip(*client_models)))
    global_models = server_aggregate(arranged_models)

In [None]:
# Evaluation graphs keyed by host id; filled by load_test_data().
data_cache_mal = {}

def load_test_data():
    """Load the three evaluation hosts into ``data_cache_mal``.

    For each id in ('201', '501', '051') the file
    ``eval_data/SysClient0<id>.systemia.com.txt`` (one JSON event per line)
    is parsed and stored as ``[docs, labels, edges, mapp, pids, idx_to_pid]``.
    """
    global data_cache_mal

    for x in ['201','501','051']:
        path = f"eval_data/SysClient0{x}.systemia.com.txt"
        # Close the file deterministically; skip blank lines.
        with open(path) as f:
            content = [json.loads(line) for line in f if line.strip()]
        docs,labels,edges,mapp,pids,idx_to_pid = prepare_graph(transform(content))
        data_cache_mal[x] = [docs,labels,edges,mapp,pids,idx_to_pid]

In [None]:
# The loader is defined as load_test_data; `load_data_test` was a NameError.
load_test_data()

In [None]:
# Evaluation: a node is flagged as anomalous when NONE of the available
# global models predicts its true type. Counts are summed over the hosts.
TP,FP,FN,TN = 0,0,0,0
FPL = set()

# The shared embedding model is the same for every host; load it once.
word2vec = Word2Vec.load("unified_word2vec.model")

for data_id in ['201','501','051']:

    docs,labels,edges,mapp,pids,idx_to_pid = data_cache_mal[data_id]

    nodes_feat = [infer(x, word2vec) for x in docs]

    with open(f"gt_{data_id}.json", "r") as json_file:
        gt,acts,objs = json.load(json_file)

    gt,acts,objs = set(gt),set(acts),set(objs)

    graph = Data(x=torch.tensor(np.asarray(nodes_feat),dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))

    # Every node starts flagged; a correct prediction by any model clears it.
    flag = torch.ones(graph.num_nodes, dtype=torch.bool)

    model = GCN().to(device)
    model.eval()

    for m_n in range(num_of_ctg):
        checkpoint = f"global{m_n}.pth"
        # `os.listdir("")` raised FileNotFoundError. A slot without a saved
        # checkpoint is skipped instead of being evaluated with whatever
        # weights the model object happens to hold.
        if not os.path.exists(checkpoint):
            continue
        model.load_state_dict(torch.load(checkpoint, map_location=device))

        with torch.no_grad():
            out = model(graph.x, graph.edge_index)

        pred = out.argmax(dim=1)
        # `flag` lives on the CPU, so the mask must be moved there before
        # indexing (matters when `device` is a GPU).
        cond = (pred == graph.y).cpu()
        flag[cond] = False

    index = utils.mask_to_index(flag).tolist()
    ids = set([mapp[x] for x in index])

    metrics = helper(set(ids),acts,objs,gt,edges,mapp)

    # Only false positives whose label is 0, 1 or 2 are counted.
    fp = [i for i in range(len(mapp)) if mapp[i] in metrics[4] and labels[i] in [0,1,2]]

    TP = TP + metrics[0]
    FP = FP + len(fp)
    FN = FN + metrics[2]
    TN = TN + metrics[3]

print(f"Number of True Positives: {TP}")
print(f"Number of Fasle Positives: {FP}")
print(f"Number of False Negatives: {FN}")
print(f"Number of True Negatives: {TN}\n")

# Guard the ratios so a run without positives reports 0 instead of raising
# ZeroDivisionError.
prec = TP / (TP + FP) if (TP + FP) else 0.0
print(f"Precision: {prec}")

rec = TP / (TP + FN) if (TP + FN) else 0.0
print(f"Recall: {rec}")

fscore = (2*prec*rec) / (prec + rec) if (prec + rec) else 0.0
print(f"Fscore: {fscore}\n")