In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch
from torch_geometric.data import Data
import os
import torch.nn.functional as F
import json 
import warnings
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
warnings.filterwarnings('ignore')
from torch_geometric.loader import NeighborLoader
import multiprocessing
from elasticsearch import Elasticsearch, helpers
import re
from tqdm import tqdm

import torch_sparse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%matplotlib inline

In [2]:
os.chdir("../../../../Trustwatch")
%pwd

'/home/jovyan/Trustwatch'

In [3]:
def load_dict_from_jsonl(file_path):
    """Merge every JSON object in a JSON-Lines file into a single dict.

    Each non-blank line is parsed with ``json.loads`` and folded into the
    result with ``dict.update``, so when a key appears on several lines the
    value from the last such line wins.

    Args:
        file_path: Path of the JSON-Lines file to read.

    Returns:
        dict: Union of all objects found in the file.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    result = {}
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a stray newline at end of file) hold no
                # record; json.loads would raise on them.
                continue
            result.update(json.loads(line))
    return result

In [4]:
def stitch(data_buffer,file_path):
    """Annotate event records with node types and write them to parquet.

    For every record the object and actor IDs are looked up in the
    ID-to-type table. NETFLOW objects additionally get their ``path``
    replaced by the first four entries of the matching ``net2prop`` value,
    joined with spaces. Records whose lookup fails get ``object`` and
    ``actor_type`` set to ``None`` and are then removed by ``dropna()``.

    Args:
        data_buffer: List of event dicts. The dicts are modified in place.
        file_path: Output name; the frame is written to
            ``e5_data/{file_path}.parquet``.
    """
    id_to_type_file = 'e5_data/clearscope_id_to_type.json'
    net2prop_file = 'e5_data/clearscope_net2prop.json'

    id_to_type = load_dict_from_jsonl(id_to_type_file)
    net2prop = load_dict_from_jsonl(net2prop_file)

    for record in data_buffer:
        try:
            typ = id_to_type[record['objectID']]
            record['object'] = typ
            record['actor_type'] = id_to_type[record['actorID']]
            if typ == 'NETFLOW':
                attr = net2prop[record['objectID']]
                record['path'] = attr[0]+' '+attr[1]+' '+attr[2]+' '+attr[3]
        except (KeyError, IndexError, TypeError):
            # Only failed lookups / malformed property lists mark the record
            # for removal; anything else (e.g. KeyboardInterrupt) propagates.
            record['object'] = None
            record['actor_type'] = None

    df = pd.DataFrame.from_records(data_buffer)
    # Drops the records marked above, and any other row containing a null.
    df = df.dropna()
    df.to_parquet(f"e5_data/{file_path}.parquet", index=False)

In [5]:
def query_elastic(query,file_path):
    """Pull matching CDM events from Elasticsearch and hand them to ``stitch``.

    Documents are streamed with ``helpers.scan``. For each one the event
    type, subject UUID, object UUID, exec name, object path and timestamps
    are extracted; a missing field becomes ``''``. Only events whose type is
    in ``edge_types`` are kept.

    Connection settings can be overridden with the ``ELASTIC_USERNAME``,
    ``ELASTIC_PASSWORD`` and ``ELASTIC_HOST`` environment variables.

    Args:
        query: Elasticsearch query body (dict) used for both count and scan.
        file_path: Output name forwarded to ``stitch``.
    """
    # SECURITY: the fallback values below are credentials committed to the
    # notebook. They are kept only so existing runs keep working; set the
    # environment variables instead and rotate this password.
    username = os.environ.get('ELASTIC_USERNAME', 'elastic')
    password = os.environ.get('ELASTIC_PASSWORD', 'stimulus5affect-roof')
    host = os.environ.get('ELASTIC_HOST', 'http://128.143.69.88:9200')
    index_name = "clearscope*"

    event_key = 'com.bbn.tc.schema.avro.cdm20.Event'
    uuid_key = 'com.bbn.tc.schema.avro.cdm20.UUID'

    def _dig(record, *keys):
        """Follow ``keys`` through nested mappings; '' if any step is missing."""
        node = record
        try:
            for key in keys:
                node = node[key]
        except (KeyError, IndexError, TypeError):
            return ''
        return node

    # Initialize Elasticsearch client
    es = Elasticsearch([host], http_auth=(username, password))

    # Test connection and index existence
    if not es.ping():
        print("Elasticsearch cluster is not accessible!")
    else:
        print("Connected to Elasticsearch.")
    if not es.indices.exists(index=index_name):
        print(f"Index {index_name} does not exist.")
    else:
        print(f"Index {index_name} exists.")

    total_docs = es.count(index=index_name, body=query)['count']

    edge_types = {
        'EVENT_CLOSE',
        'EVENT_OPEN',
        'EVENT_READ',
        'EVENT_WRITE',
        'EVENT_EXECUTE',
        'EVENT_RECVFROM',
        'EVENT_RECVMSG',
        'EVENT_SENDMSG',
        'EVENT_SENDTO',
    }

    info_buffer = []

    # Start processing documents
    with tqdm(total=total_docs, desc="Processing Documents") as pbar:
        for doc in helpers.scan(es, query=query, index=index_name, size=1000):
            pbar.update(1)

            line = doc['_source']

            action = _dig(line, 'datum', event_key, 'type')
            actor = _dig(line, 'datum', event_key, 'subject', uuid_key)
            obj = _dig(line, 'datum', event_key, 'predicateObject', uuid_key)
            cmd = _dig(line, 'datum', event_key, 'properties', 'map', 'exec')
            path = _dig(line, 'datum', event_key, 'predicateObjectPath', 'string')

            # The two timestamps are all-or-nothing: if either is missing,
            # both are recorded as ''.
            try:
                timestampnano = line['datum'][event_key]['timestampNanos']
                timestamp = line['@timestamp']
            except (KeyError, TypeError):
                timestamp = ''
                timestampnano = ''

            if action in edge_types:
                info_data = {'actorID': actor, 'objectID': obj, 'action': action, 'timestampNanos': timestampnano, 'timestamp': timestamp, 'exec': cmd, 'path': path, 'hostid': line['hostId']}
                info_buffer.append(info_data)

    stitch(info_buffer,file_path)

In [6]:
import concurrent.futures

def prepare_elastic_dataset():
    """Run the three dataset queries concurrently, one thread per query.

    Each query is passed to ``query_elastic`` together with the name of the
    parquet file it should produce ("ben_clearscope", "mal_clear_gt_ids",
    "mal_clearscope"). An exception raised by a query is printed, not
    re-raised.

    NOTE: a later cell in this notebook defines a function with the same
    name, which replaces this one once that cell has run.
    """
    ts_format = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"

    def window(start, end):
        # One inclusive @timestamp range clause.
        return {"range": {"@timestamp": {"gte": start, "lte": end, "format": ts_format}}}

    def any_of(*windows):
        # Query matching documents that fall in at least one of the windows.
        return {"query": {"bool": {"should": list(windows), "minimum_should_match": 1}}}

    # (query body, output name) pairs
    jobs = [
        ({"query": window("2019-05-09T00:00:00.360Z", "2019-05-09T07:00:00.360Z")}, "ben_clearscope"),
        (any_of(
            window("2019-05-15T14:07:59.360Z", "2019-05-15T14:23:00.360Z"),
            window("2019-05-15T14:23:00.360Z", "2019-05-15T14:38:02.360Z"),
            window("2019-05-15T15:38:59.360Z", "2019-05-15T15:55:38.360Z"),
            window("2019-05-15T15:55:38.360Z", "2019-05-15T16:11:27.360Z"),
            window("2019-05-17T14:50:52.360Z", "2019-05-17T15:06:00.360Z"),
            window("2019-05-17T15:06:00.360Z", "2019-05-17T15:21:40.360Z"),
            window("2019-05-17T15:21:40.360Z", "2019-05-17T15:36:41.360Z"),
            window("2019-05-17T15:36:41.360Z", "2019-05-17T15:51:43.360Z"),
            window("2019-05-17T15:51:43.360Z", "2019-05-17T16:06:44.360Z"),
            window("2019-05-17T16:21:46.360Z", "2019-05-17T16:36:47.360Z"),
            window("2019-05-17T16:36:47.360Z", "2019-05-17T16:51:49.360Z"),
        ), "mal_clear_gt_ids"),
        (any_of(
            window("2019-05-15T14:00:00.360Z", "2019-05-15T18:00:00.360Z"),
            window("2019-05-17T14:00:00.360Z", "2019-05-17T18:00:00.360Z"),
        ), "mal_clearscope"),
    ]

    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(query_elastic, body, out_name) for body, out_name in jobs]

        # query_elastic returns nothing; result() is called only so that a
        # failure inside a worker is reported here.
        for done in concurrent.futures.as_completed(pending):
            try:
                done.result()
            except Exception as exc:
                print(f'Generated an exception: {exc}')

In [7]:
#prepare_elastic_dataset()

In [8]:
def prepare_elastic_dataset():
    """Run the three dataset queries one after another.

    Produces, via ``query_elastic``, the parquet files "ben_clearscope"
    (one day starting 2019-05-08), "mal_clear_gt_ids" (eleven short windows
    on 2019-05-15 and 2019-05-17) and "mal_clearscope" (three multi-hour
    windows on the same two days).
    """
    ts_format = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"

    def window(start, end):
        # One inclusive @timestamp range clause.
        return {"range": {"@timestamp": {"gte": start, "lte": end, "format": ts_format}}}

    def any_of(spans):
        # Query matching documents that fall in at least one (start, end) span.
        return {
            "query": {
                "bool": {
                    "should": [window(start, end) for start, end in spans],
                    "minimum_should_match": 1,
                }
            }
        }

    benign_query = {"query": window("2019-05-08T00:00:00.360Z", "2019-05-09T00:00:00.360Z")}
    query_elastic(benign_query,"ben_clearscope")

    ground_truth_spans = [
        ("2019-05-15T14:07:59.360Z", "2019-05-15T14:23:00.360Z"),
        ("2019-05-15T14:23:00.360Z", "2019-05-15T14:38:02.360Z"),
        ("2019-05-15T15:38:59.360Z", "2019-05-15T15:55:38.360Z"),
        ("2019-05-15T15:55:38.360Z", "2019-05-15T16:11:27.360Z"),
        ("2019-05-17T14:50:52.360Z", "2019-05-17T15:06:00.360Z"),
        ("2019-05-17T15:06:00.360Z", "2019-05-17T15:21:40.360Z"),
        ("2019-05-17T15:21:40.360Z", "2019-05-17T15:36:41.360Z"),
        ("2019-05-17T15:36:41.360Z", "2019-05-17T15:51:43.360Z"),
        ("2019-05-17T15:51:43.360Z", "2019-05-17T16:06:44.360Z"),
        ("2019-05-17T16:21:46.360Z", "2019-05-17T16:36:47.360Z"),
        ("2019-05-17T16:36:47.360Z", "2019-05-17T16:51:49.360Z"),
    ]
    query_elastic(any_of(ground_truth_spans),"mal_clear_gt_ids")

    # The first two spans overlap (15:00-16:00 on 2019-05-15).
    attack_day_spans = [
        ("2019-05-15T14:00:00.360Z", "2019-05-15T16:00:00.360Z"),
        ("2019-05-15T15:00:00.360Z", "2019-05-15T18:00:00.360Z"),
        ("2019-05-17T14:00:00.360Z", "2019-05-17T18:00:00.360Z"),
    ]
    query_elastic(any_of(attack_day_spans),"mal_clearscope")

In [9]:
#prepare_elastic_dataset()

In [10]:
def train_test_split():
    """Load the benign events and split them into one DataFrame per host.

    Returns:
        dict: ``hostid`` value -> DataFrame holding that host's rows of
        ``e5_data/ben_clearscope.parquet``.
    """
    benign_events = pd.read_parquet('e5_data/ben_clearscope.parquet')
    # Iterating a GroupBy yields (key, sub-frame) pairs.
    return dict(iter(benign_events.groupby('hostid')))

In [11]:
hostdfs = train_test_split()

## Loading libraries and setting up working directory

In [12]:
'''
Importing some additional libraries
'''
from pprint import pprint
import gzip
from sklearn.manifold import TSNE
import json
import copy
import os

In [13]:
# Number of GNN models per client: init_gnns() builds this many, and
# define_categories() splits the process IDs into num_of_ctg - 1 groups
# (the last model is the catch-all trained on the whole graph).
num_of_ctg = 10
# Number of federated rounds run by the training cell.
learning_rounds = 5
# Epochs each GNN is trained for inside train_gnn_func().
epochs = 10
# One federated client per host ID found in the benign data.
hosts = list(hostdfs.keys())
# Gates the cells that write artifacts (word table, checkpoints) and train.
TRAIN=True

## Defining functions for loading, cleaning and constructing features from the data

In [14]:
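'''
This is the main featurizer. It constructs the graph for the clearscope dataset.

Args:
    df (DataFrame): This is the main dataframe containing all the system events from the clearscope dataset.

return:
    features (list): For each node, the list of exec / action / path tokens collected from its events (not yet embedded).
    feat_labels (list): Contains the label (index into `tokens`) for each node
    edge_index (list): Two parallel lists of source and destination node indices, one entry per event.
    mapp (list): contains id of each node
    all_procs (list): Unique actor IDs in order of first appearance.
    idx_to_proc (dict): Maps each position in all_procs to its actor ID.
'''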

# Node types kept in the graph; a node's label is its index in this list.
tokens = ['SUBJECT_PROCESS',
          'FILE_OBJECT_FILE',
          'NETFLOW']

def prepare_graph(df):
    """Build node token lists, labels and an edge list from an event frame.

    Every distinct ``actorID`` / ``objectID`` becomes a node. Each event adds
    its ``exec`` and ``action`` (and its ``path`` when non-empty) to the
    token list of both endpoints and contributes one actor -> object edge.
    The input frame is only read, never modified.

    Args:
        df: Event DataFrame with columns ``actorID``, ``objectID``,
            ``actor_type``, ``object``, ``exec``, ``action`` and ``path``.

    Returns:
        tuple: ``(features, feat_labels, edge_index, mapp, all_procs,
        idx_to_proc)`` where

        * ``features`` - per-node list of tokens, ordered like ``mapp``;
        * ``feat_labels`` - per-node index into ``tokens`` (NaN when the
          node's type is not in ``tokens``);
        * ``edge_index`` - ``[sources, destinations]`` as positions in ``mapp``;
        * ``mapp`` - node IDs, actors first, in order of first appearance;
        * ``all_procs`` - unique actor IDs in order of first appearance;
        * ``idx_to_proc`` - position in ``all_procs`` -> actor ID.
    """
    dummies = {token: index for index, token in enumerate(tokens)}

    # Kept as local series instead of new columns so that the caller's frame
    # (often a filtered slice of a bigger one) is left untouched.
    actor_label = df['actor_type'].map(dummies)
    object_label = df['object'].map(dummies)

    # Insertion order fixes node numbering: actors first, then objects that
    # have not already appeared as actors.
    nodes = {}
    for col in ['actorID', 'objectID']:
        for uid in df[col].unique():
            nodes[uid] = []

    # Object labels are applied last, so an ID seen in both roles ends up
    # with its object label.
    labels = dict(zip(df['actorID'], actor_label.tolist()))
    labels.update(zip(df['objectID'], object_label.tolist()))

    for actor, obj, cmd, action, path in zip(
            df['actorID'], df['objectID'], df['exec'], df['action'], df['path']):
        nodes[actor].extend([cmd, action])
        nodes[obj].extend([cmd, action])
        if path != '':
            nodes[actor].append(path)
            nodes[obj].append(path)

    mapp = list(nodes.keys())
    features = [nodes[node_id] for node_id in mapp]
    feat_labels = [labels[node_id] for node_id in mapp]

    index_map = {node_id: index for index, node_id in enumerate(mapp)}
    edge_index = [
        [index_map[src] for src in df['actorID']],
        [index_map[dst] for dst in df['objectID']],
    ]

    all_procs = list(df['actorID'].unique())
    idx_to_proc = {index: proc for index, proc in enumerate(all_procs)}

    return features, feat_labels, edge_index, mapp, all_procs, idx_to_proc

In [15]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv, GATConv
import torch.nn.functional as F
import torch.nn as nn

class GCN(torch.nn.Module):
    """Two-layer graph network used as the per-category node classifier.

    Despite the name, both layers are ``SAGEConv`` layers (with
    ``normalize=True``); the hidden width is 32.
    """

    def __init__(self,in_channel,out_channel):
        super().__init__()
        hidden_width = 32
        self.conv1 = SAGEConv(in_channel, hidden_width, normalize=True)
        self.conv2 = SAGEConv(hidden_width, out_channel, normalize=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return per-node class probabilities (softmax over dim 1).

        NOTE(review): callers in this notebook pass this output to
        CrossEntropyLoss - confirm that feeding probabilities rather than
        raw scores is intended.
        """
        scores = self.encode(x, edge_index)
        return torch.softmax(scores, dim=1)

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return the raw second-layer output: conv1 -> tanh -> conv2."""
        hidden = torch.tanh(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

    def freeze_conv_layers(self):
        """Turn off gradient tracking for every parameter of both layers."""
        for layer in (self.conv1, self.conv2):
            for weight in layer.parameters():
                weight.requires_grad = False

In [16]:
'''
This function helps visualize the output of the model.
'''
def visualize(h, color):
    """Scatter-plot a 2-D t-SNE projection of ``h``.

    Args:
        h: Tensor of per-row embeddings; it is detached and moved to the
            CPU before projection.
        color: Per-point colour values passed to ``scatter`` (Set2 colormap).
    """
    embeddings = h.detach().cpu().numpy()
    projected = TSNE(n_components=2).fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(10,10))
    ax.set_xticks([])
    ax.set_yticks([])

    ax.scatter(projected[:, 0], projected[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

In [17]:
def combine_word2vec_models(models):
    """Merge several Word2Vec models into one word -> vector dictionary.

    A word present in a single model keeps that model's vector. A word
    present in several models gets the arithmetic mean of all its vectors,
    so every contributing model has equal weight regardless of its position
    in ``models``.

    Args:
        models: Sequence of objects exposing ``wv.index_to_key`` and
            ``wv[word]`` (e.g. gensim ``Word2Vec`` models). May be empty.

    Returns:
        dict: word -> vector as a plain list of floats (JSON-serialisable),
        in order of first appearance across ``models``.
    """
    sums = {}
    counts = {}

    for model in models:
        for word in model.wv.index_to_key:
            vector = np.asarray(model.wv[word], dtype=np.float64)
            if word in sums:
                sums[word] = sums[word] + vector
                counts[word] += 1
            else:
                # Fresh array, so the running sum never aliases model storage.
                sums[word] = vector.copy()
                counts[word] = 1

    # Dividing once at the end gives a true mean; a pairwise running average
    # would give models later in the list a larger weight.
    return {word: (total / counts[word]).tolist() for word, total in sums.items()}

In [18]:
from gensim.models.callbacks import CallbackAny2Vec
import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

In [19]:
class EpochSaver(CallbackAny2Vec):
    '''Training callback that writes the model to disk after every epoch.'''

    def __init__(self,filename):
        # Every save goes to this one path, overwriting the previous epoch.
        self.filename = filename
        self.epoch = 0

    def on_epoch_end(self, model):
        model.save(self.filename)
        self.epoch = self.epoch + 1

In [20]:
class EpochLogger(CallbackAny2Vec):
    '''Training callback that prints a line at the start and end of each epoch.'''

    def __init__(self):
        # Zero-based number of the epoch currently running.
        self.epoch = 0

    def on_epoch_begin(self, model):
        message = "Epoch #{} start".format(self.epoch)
        print(message)

    def on_epoch_end(self, model):
        message = "Epoch #{} end".format(self.epoch)
        print(message)
        self.epoch = self.epoch + 1

In [21]:
def train_word2vec_models():
    """Train one Word2Vec model per host on that host's node token lists.

    The trained model is not returned; it reaches disk through the
    ``EpochSaver`` callback, which saves to ``Content_FL_Exp/{host}.model``
    after every epoch.
    """
    for host in hosts:
        print("Running host:",host)
        callbacks = [EpochSaver(f"Content_FL_Exp/{host}.model"), EpochLogger()]

        # Keep only process-initiated events whose object type is in `tokens`.
        events = hostdfs[host]
        events = events[events['actor_type'] == 'SUBJECT_PROCESS']
        events = events[events['object'].isin(tokens)]

        # Only the per-node token lists (first return value) are needed here.
        sentences = prepare_graph(events)[0]
        Word2Vec(
            sentences=sentences,
            vector_size=30,
            window=5,
            min_count=1,
            workers=5,
            epochs=100,
            callbacks=callbacks,
        )

In [22]:
# Build the shared word-vector table: load every host's saved Word2Vec
# model, merge them with combine_word2vec_models(), and store the result as
# JSON for load_word_model() to read.
# NOTE: the output recorded below this cell shows the write failing with
# "Read-only file system"; in that environment an existing JSON file is
# what later cells load.
if TRAIN:
    word_models = []
    for m in hosts:
        # Expects train_word2vec_models() to have produced one file per host.
        word2vec = Word2Vec.load(f"Content_FL_Exp/{m}.model")
        word_models.append(word2vec)
        
    global_word = combine_word2vec_models(word_models)

    with open('Content_FL_Exp/e5_clearscope_word2vec_global.json', 'w') as json_file:
        json.dump(global_word, json_file)

OSError: [Errno 30] Read-only file system: 'Content_FL_Exp/e5_clearscope_word2vec_global.json'

In [23]:
def load_word_model():
    """Read the shared word-vector table from disk.

    Returns:
        dict: word -> ``numpy.ndarray`` built from the lists stored in
        ``Content_FL_Exp/e5_clearscope_word2vec_global.json``.
    """
    table_path = 'Content_FL_Exp/e5_clearscope_word2vec_global.json'
    with open(table_path, 'r') as json_file:
        stored_table = json.load(json_file)

    vectors = {}
    for word, values in stored_table.items():
        vectors[word] = np.array(values)
    return vectors

In [24]:
'''
Defining the train and test function in this cell 
'''
from sklearn.utils import class_weight
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

In [25]:
from collections import Counter
word2vec = load_word_model()

def infer(doc):
    """Embed a list of tokens as a single 30-dimensional vector.

    Each token found in the module-level ``word2vec`` table contributes its
    vector multiplied by the number of times it occurs in ``doc``; the sum
    is divided by the number of *distinct* tokens that were found.

    NOTE(review): the numerator is weighted by occurrences but the
    denominator counts distinct tokens, so this is not a plain mean. It is
    left as is because the saved GNN checkpoints were trained on it.

    Args:
        doc: Iterable of hashable tokens.

    Returns:
        numpy.ndarray: Vector of length 30; all zeros when no token of
        ``doc`` is in the table (instead of the NaNs a 0/0 would produce).
    """
    token_counts = Counter(doc)
    emb = np.zeros(30)
    count = 0
    for token, occurrences in token_counts.items():
        if token in word2vec:
            emb = emb + word2vec[token] * occurrences
            count = count + 1
    if count == 0:
        # Nothing recognised: a zero vector keeps NaN out of node features.
        return emb
    return emb / count

In [26]:
def init_gnns():
    """Create ``num_of_ctg`` freshly initialised GCN models on ``device``.

    Each model takes 30 input features and has one output per entry of
    ``tokens``.
    """
    return [GCN(30,len(tokens)).to(device) for _ in range(num_of_ctg)]

In [27]:
def define_categories(pids):
    """Split the distinct process IDs into ``num_of_ctg - 1`` groups.

    The IDs are de-duplicated, put in a fixed (sorted) order and cut into
    consecutive chunks whose sizes differ by at most one. Sorting makes the
    partition reproducible: plain ``set`` iteration order for strings can
    differ from one interpreter session to the next.

    Args:
        pids: Iterable of process IDs; duplicates are allowed.

    Returns:
        list[set]: ``num_of_ctg - 1`` disjoint sets covering all IDs, or an
        empty list when ``num_of_ctg`` is 1 or less.
    """
    n = num_of_ctg - 1
    if n <= 0:
        # Only the catch-all model exists; there is nothing to partition.
        return []
    # key=str keeps the sort well defined even if IDs of mixed types occur.
    ctg = sorted(set(pids), key=str)
    k, m = divmod(len(ctg), n)
    # The first m chunks hold k + 1 IDs, the remaining ones k.
    return [set(ctg[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]) for i in range(n)]

In [28]:
def map_pids_to_category_indices(pids, categories):
    """Map each process ID to the index of the first category containing it.

    Args:
        pids: Iterable of process IDs.
        categories: Sequence of containers (sets) of process IDs.

    Returns:
        dict: pid -> category index. IDs that belong to no category are
        left out.
    """
    lookup = {}
    for pid in pids:
        first_match = next(
            (position for position, members in enumerate(categories) if pid in members),
            None,
        )
        if first_match is not None:
            lookup[pid] = first_match
    return lookup

In [29]:
from torch.nn import CrossEntropyLoss
from sklearn.utils import class_weight
import copy

templates = init_gnns()

def train_gnn_func(nodes,labels,edges,mapp,pids,idx_to_pid):
    """Train this client's copy of every category GNN and the catch-all GNN.

    Process nodes are assigned to categories through the module-level
    ``categories``. For each category with at least one local process node,
    a copy of the matching template model is (optionally) initialised from
    the saved global checkpoint and trained with ``NeighborLoader``
    mini-batches. The seed nodes are the category's process nodes plus the
    nodes reached by following outgoing edges twice from them. The last
    model is trained full-batch on the whole graph.

    NOTE(review): the mini-batch loss is taken over every node of each
    sampled sub-graph, not only its seed nodes - confirm this is intended.

    Args:
        nodes: Per-node feature vectors (length 30 each).
        labels: Per-node class indices.
        edges: ``[sources, destinations]`` node-index lists.
        mapp: Node IDs, position = node index.
        pids: Process IDs present on this client.
        idx_to_pid: Node index -> process ID. Keys may be ints (as built by
            ``prepare_graph``) or strings (after a JSON round trip).

    Returns:
        list: One entry per template model; ``None`` for categories with no
        local process node, otherwise the trained model.
    """
    pid_to_gnn_index = map_pids_to_category_indices(pids, categories)

    set_pids = set(pids)

    proc_index = [i for i in range(len(mapp)) if mapp[i] in set_pids]

    def pid_at(node_index):
        # JSON turns integer keys into strings, so accept both spellings.
        key = str(node_index)
        if key in idx_to_pid:
            return idx_to_pid[key]
        return idx_to_pid[node_index]

    def get_neighbors(edge_index, seeds):
        # Distinct destinations of all edges that start at one of `seeds`.
        neighbors = []
        for node in seeds:
            outgoing = edge_index[0] == node
            neighbors.extend(edge_index[1, outgoing].tolist())
        return torch.tensor(list(set(neighbors)), dtype=torch.long)

    train_splits = [[] for _ in range(len(categories))]

    for i in proc_index:
        split_indx = pid_to_gnn_index[pid_at(i)]
        train_splits[split_indx].append(int(i))

    local_models = [copy.deepcopy(x) for x in templates]

    # The graph is the same for every model, so it is built once.
    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))

    for i in range(len(local_models)-1):

        if len(train_splits[i]) == 0:
            local_models[i] = None
        else:
            checkpoint = f"Content_FL_Exp/target_e5_clearscope_global{i}.pth"
            if os.path.isfile(checkpoint):
                # map_location lets a checkpoint saved on a GPU machine be
                # loaded on a CPU-only one, and vice versa.
                local_models[i].load_state_dict(torch.load(checkpoint, map_location=device))

            optimizer = torch.optim.Adam(local_models[i].parameters(), lr=0.01, weight_decay=5e-4)
            criterion = CrossEntropyLoss()

            mask = torch.tensor([False]*graph.num_nodes, dtype=torch.bool)
            mask[train_splits[i]] = True

            one_hop_neighbors = get_neighbors(graph.edge_index, train_splits[i])
            two_hop_neighbors = get_neighbors(graph.edge_index, one_hop_neighbors)
            # Keep only second-hop nodes that are not already seeds.
            two_hop_neighbors = two_hop_neighbors[~mask[two_hop_neighbors]]
            mask[two_hop_neighbors] = True

            for epoch in range(epochs):
                print(f'Training GNN Category {i} Model for Epoch {epoch}')

                loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
                total_loss = 0
                for subg in loader:
                    local_models[i].train()
                    optimizer.zero_grad()
                    out = local_models[i](subg.x, subg.edge_index)
                    loss = criterion(out, subg.y)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item() * subg.batch_size
                print("Loss: ", total_loss / mask.sum().item(), '\n')

    # Catch-all model: full-batch training on the entire graph.
    optimizer = torch.optim.Adam(local_models[-1].parameters(), lr=0.01, weight_decay=5e-4)
    criterion = CrossEntropyLoss()

    for epoch in range(epochs):
        print(f'Training Catch all GNN Category Model for Epoch {epoch}')
        local_models[-1].train()
        optimizer.zero_grad()
        out = local_models[-1](graph.x, graph.edge_index)
        loss = criterion(out, graph.y)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Loss: {loss.item()}")

    return local_models

In [30]:
procs_total = []
data_cache = {}
categories = None

def load_clients_data():
    """Build and cache every host's graph, then derive the categories.

    Fills the module-level ``data_cache`` (host -> prepare_graph outputs as
    a list), appends each host's process IDs to ``procs_total`` and sets
    ``categories`` from the accumulated IDs.
    """
    global categories, procs_total

    for host in hosts:
        events = hostdfs[host]
        # Keep only process-initiated events whose object type is in `tokens`.
        events = events[events['actor_type'] == 'SUBJECT_PROCESS']
        events = events[events['object'].isin(tokens)]

        graph_parts = list(prepare_graph(events))
        data_cache[host] = graph_parts
        # graph_parts[4] holds the host's unique process IDs.
        procs_total = procs_total + graph_parts[4]

    categories = define_categories(procs_total)

In [31]:
def client_handling_loop(client_id):    
    """Embed one client's cached node documents and train its local GNNs.

    Args:
        client_id: Key of the client in the module-level ``data_cache``.

    Returns:
        The list of models returned by ``train_gnn_func``.
    """
    print(f"Running Setup on Client {client_id} \n")

    docs, labels, edges, mapp, pids, idx_to_pid = data_cache[client_id]

    # One embedding per node, in node order.
    node_vectors = [infer(doc) for doc in docs]

    return train_gnn_func(node_vectors,labels,edges,mapp,pids,idx_to_pid)

In [40]:
def server_aggregate(all_models):
    """Average the clients' models slot by slot and save each result.

    Args:
        all_models: ``all_models[l]`` is the list of client models for slot
            ``l``; ``None`` entries (clients that skipped the slot) are
            ignored.

    Returns:
        list: Deep copies of ``templates`` in which every slot that received
        at least one client model holds the element-wise mean of those
        models' state dicts. Such slots are also written to
        ``Content_FL_Exp/target_e5_clearscope_global{l}.pth``.
    """
    global_models = copy.deepcopy(templates)

    for slot, submitted in enumerate(all_models):
        trained = [model for model in submitted if model is not None]
        if not trained:
            # No client trained this slot; it keeps the template weights.
            continue

        client_states = [model.state_dict() for model in trained]
        averaged = global_models[slot].state_dict()
        for key in averaged.keys():
            stacked = torch.stack([state[key] for state in client_states], 0)
            averaged[key] = stacked.mean(0)

        global_models[slot].load_state_dict(averaged)
        torch.save(global_models[slot].state_dict(), f"Content_FL_Exp/target_e5_clearscope_global{slot}.pth")

    return global_models

In [33]:
def thread_local_gnns(c):
    """Thread entry point: train client ``c`` and return its local models."""
    return client_handling_loop(c)

In [34]:
import random
def perform_federated_learning(n_clients):
    """Train all clients concurrently and collect their local models.

    Args:
        n_clients: Iterable of client IDs (despite the name, not a count).

    Returns:
        list: One ``thread_local_gnns`` result per client, in the same order
        as ``n_clients``.

    Raises:
        Exception: Whatever a client's training raised, re-raised here.
    """
    with concurrent.futures.ThreadPoolExecutor() as executer:
        futures = [executer.submit(thread_local_gnns, c) for c in n_clients]

    # Leaving the with-block waits for every worker. Reading the futures in
    # submission order (rather than completion order) keeps the result list
    # aligned with n_clients.
    return [future.result() for future in futures]

In [35]:
if TRAIN:   
    !rm Content_FL_Exp/target_e5_clearscope*.pth

rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global0.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global1.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global2.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global3.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global4.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global5.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global6.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global7.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global8.pth': Read-only file system
rm: cannot remove 'Content_FL_Exp/target_e5_clearscope_global9.pth': Read-only file system


In [36]:
# Load the per-host graph data from a JSON file instead of rebuilding it
# with load_clients_data() (whose call is commented out in the next cell).
# presumably each entry has the six-item layout that load_clients_data()
# stores - verify against the code that wrote this file.
with open('Content_FL_Exp/e5_clearscope_ensemble_ben.json', 'r') as f:
    data_cache = json.load(f)

# Collect the second-to-last item of every host's entry (the process-ID list
# in the layout assumed above).
# NOTE: this name is `proc_total`; load_clients_data() uses `procs_total`.
proc_total = []
for x in hosts:
    proc_total = proc_total + data_cache[x][-2]
    
categories = define_categories(proc_total)

In [37]:
#load_clients_data()

In [41]:
# Federated training: each round trains every host's local models and then
# averages them on the "server".
if TRAIN:
    for r in range(learning_rounds):
        print(f"Federated Learning Round Number: {r}\n")
        # One list of local models per client.
        client_models = perform_federated_learning(hosts)
        # Transpose to one list per model slot, holding that slot's model
        # from every client.
        arranged_models =  [list(group) for group in zip(*client_models)]
        # Averages each slot and saves it to disk.
        global_models = server_aggregate(arranged_models)

True
Federated Learning Round Number: 0

Running Setup on Client 055F80A9-714A-5BFA-259A-E6046D38EA38 

Running Setup on Client 54FF20FC-635E-6455-F04F-EA4FA27EBC1E 

Running Setup on Client 860178F8-0FE9-66CC-8EE2-F6BBD1A59DAB 

Training GNN Category 0 Model for Epoch 0
Training GNN Category 0 Model for Epoch 0
Training GNN Category 0 Model for Epoch 0
Loss:  0.8348554372787476 

Training GNN Category 0 Model for Epoch 1
Loss:  0.8321799039840698 

Training GNN Category 0 Model for Epoch 1
Loss:  0.9363691806793213 

Training GNN Category 0 Model for Epoch 1
Loss:  0.844879686832428 

Training GNN Category 0 Model for Epoch 2
Loss:  0.9342685341835022 

Training GNN Category 0 Model for Epoch 2
Loss:  0.8294360041618347 

Training GNN Category 0 Model for Epoch 3
Loss:  0.9523756504058838 

Training GNN Category 0 Model for Epoch 2
Loss:  0.8296067118644714 

Training GNN Category 0 Model for Epoch 4
Loss:  0.9333190321922302 

Training GNN Category 0 Model for Epoch 3
Loss:  0.846058

## Evaluation of the trained GNN model starts here

In [42]:
'''
This function is used for constructing neighborhood around a given 
set of nodes for backwards or forward tracking
'''
from itertools import compress
from torch_geometric import utils

def construct_neighborhood(ids,mapp,edges,hops):
    """Collect the node IDs reachable from ``ids`` within ``hops`` steps.

    Edges are followed in both directions. The starting IDs themselves are
    included only if some path leads back to them.

    Args:
        ids: Container of starting node IDs (anything supporting ``in``).
        mapp: Node IDs, position = node index.
        edges: ``[sources, destinations]`` node-index lists.
        hops: Maximum number of steps; zero or a negative value yields an
            empty set.

    Returns:
        set: IDs of all nodes reached in 1..hops steps.
    """
    reached = set()
    frontier = ids
    # Iterative rather than recursive, so a negative `hops` cannot recurse
    # without end.
    for _ in range(max(hops, 0)):
        step = set()
        for src, dst in zip(edges[0], edges[1]):
            src_id, dst_id = mapp[src], mapp[dst]
            if src_id in frontier:
                step.add(dst_id)
            if dst_id in frontier:
                step.add(src_id)
        if not step:
            # Nothing new can be reached from an empty frontier.
            break
        reached |= step
        frontier = step
    return reached

In [43]:
'''
This function logs the evaluation metrics.
'''

def helper(MP,all_pids,GP,edges,mapp):
    """Print neighbourhood-relaxed detection metrics and return the hit sets.

    Starting from the exact sets (TP = MP & GP, FP = MP - GP, FN = GP - MP)
    the counts are relaxed with two-hop neighbourhoods:

    * a false positive within two hops of a ground-truth node is dropped;
    * a false negative within two hops of a true positive counts as a true
      positive.

    Precision, recall and F-score are printed from the relaxed counts; a
    ratio whose denominator is zero is reported as 0.0.

    Args:
        MP: Set of node IDs flagged by the model.
        all_pids: Set of all node IDs (kept for interface compatibility;
            the printed metrics do not depend on it).
        GP: Set of ground-truth malicious node IDs.
        edges: ``[sources, destinations]`` node-index lists.
        mapp: Node IDs, position = node index.

    Returns:
        tuple: ``(TPL, FPL)`` - the relaxed true-positive and false-positive
        ID sets.
    """
    def ratio(numerator, denominator):
        # 0.0 instead of ZeroDivisionError for an empty denominator.
        return numerator / denominator if denominator else 0.0

    TP = MP.intersection(GP)
    FP = MP - GP
    FN = GP - MP

    two_hop_gp = construct_neighborhood(GP,mapp,edges,2)
    two_hop_tp = construct_neighborhood(TP,mapp,edges,2)
    FPL = FP - two_hop_gp
    TPL = TP.union(FN.intersection(two_hop_tp))
    FN = FN - two_hop_tp

    tp_count, fp_count, fn_count = len(TPL), len(FPL), len(FN)

    print(f"Number of True Positives: {tp_count}")
    print(f"Number of False Positives: {fp_count}")
    print(f"Number of False Negatives: {fn_count}")

    prec = ratio(tp_count, tp_count + fp_count)
    print(f"Precision: {prec}")

    rec = ratio(tp_count, tp_count + fn_count)
    print(f"Recall: {rec}")

    fscore = ratio(2*prec*rec, prec + rec)
    print(f"Fscore: {fscore}\n")

    return TPL,FPL

In [44]:
def generate_groundtruth():
    """Return the IDs of all nodes involved in the ground-truth events.

    Reads ``e5_data/mal_clear_gt_ids.parquet``, keeps events whose actor is
    a SUBJECT_PROCESS and whose object type is in ``tokens``, and returns
    the union of their actor and object IDs.

    Returns:
        set: Node IDs.
    """
    events = pd.read_parquet('e5_data/mal_clear_gt_ids.parquet')

    keep = (events['actor_type'] == 'SUBJECT_PROCESS') & events['object'].isin(tokens)
    events = events[keep]

    return set(events['actorID'].unique()) | set(events['objectID'].unique())

In [45]:
GT_mal = generate_groundtruth()

In [46]:
def load_data_test():
    """Build the evaluation graph from ``e5_data/mal_clearscope.parquet``.

    Returns:
        list: The six ``prepare_graph`` outputs for the events whose actor
        is a SUBJECT_PROCESS and whose object type is in ``tokens``.
    """
    events = pd.read_parquet('e5_data/mal_clearscope.parquet')

    process_events = events[events['actor_type'] == 'SUBJECT_PROCESS']
    relevant_events = process_events[process_events['object'].isin(tokens)]

    return list(prepare_graph(relevant_events))

In [47]:
#data_mal = load_data_test()

In [48]:
# Load the evaluation graph from a JSON file instead of rebuilding it with
# load_data_test() (whose call is commented out in the previous cell).
# run_evaluation() unpacks data_mal into six items.
with open('e5_data/clearscope_mal_proc.json', 'r') as file:
    data_mal =  json.load(file)

In [49]:
def run_evaluation(thresh):    
    """Flag anomalous nodes with the saved GNN ensemble and print metrics.

    Every node starts out flagged. Each saved global model un-flags the
    nodes whose type it predicts correctly with a normalised confidence of
    at least ``thresh``. The nodes still flagged after all models have run
    are the alerts, which ``helper`` scores against ``GT_mal``.

    Each worker thread builds its own model instance, so concurrent weight
    loading cannot interfere; model slots without a checkpoint on disk are
    skipped.

    NOTE(review): confidence is rescaled as ``(conf - min) / max`` rather
    than by ``max - min`` - confirm this is intended before changing it,
    since ``thresh`` was tuned against this scaling.

    Args:
        thresh: Minimum normalised confidence for a correct prediction to
            clear a node.

    Raises:
        Exception: Whatever a model's evaluation raised, re-raised here.
    """
    global word2vec

    phrases, labels, edges, mapp, _pids, _idx_to_pid = data_mal

    # Rebinds the module-level table that infer() reads.
    word2vec = load_word_model()

    nodes = np.array([infer(x) for x in phrases])

    all_ids = set(mapp)

    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))

    flag = torch.ones(graph.num_nodes, dtype=torch.bool)

    saved_files = set(os.listdir("Content_FL_Exp"))

    def model_evaluation_loop(m_n):
        """Return the CPU mask of nodes cleared by model ``m_n`` (or None)."""
        if f"target_e5_clearscope_global{m_n}.pth" not in saved_files:
            # Without trained weights the model could only clear nodes at random.
            return None

        model = GCN(30,len(tokens)).to(device)
        model.load_state_dict(torch.load(f"Content_FL_Exp/target_e5_clearscope_global{m_n}.pth",map_location=torch.device('cpu')))

        model.eval()
        with torch.no_grad():
            out = model(graph.x, graph.edge_index)

        ranked, indices = out.sort(dim=1,descending=True)
        # Relative margin between the best and second-best class.
        conf = (ranked[:,0] - ranked[:,1]) / ranked[:,0]
        conf = (conf - conf.min()) / conf.max()

        pred = indices[:,0]
        cond = (pred == graph.y) & (conf >= thresh)
        # `flag` lives on the CPU, so the mask must too.
        return cond.cpu()

    with concurrent.futures.ThreadPoolExecutor() as executer:
        futures = [executer.submit(model_evaluation_loop, m_n) for m_n in range(num_of_ctg)]

    # result() re-raises worker exceptions that would otherwise go unnoticed.
    for future in futures:
        cleared = future.result()
        if cleared is not None:
            flag[cleared] = False

    index = utils.mask_to_index(flag).tolist()
    ids = set([mapp[x] for x in index])
    helper(set(ids),set(all_ids),GT_mal,edges,mapp)

In [50]:
run_evaluation(0.7)

Number of True Positives: 14067
Number of Fasle Positives: 238
Number of False Negatives: 0
Precision: 0.9833624606780846
Recall: 1.0
Fscore: 0.9916114479063866

