In [16]:
import requests
from tqdm import tqdm

def download_atlasv2():
    url = "https://public.boxcloud.com/d/1/b1!DF-5pZffA7zR1_FcUVCjx4PJaM3r8Lx6kBOjMXR9DG-geexQh9pfYrZNUzBLVe5vu620bSdhYCFpRhlwpF8DRlkV1VXDUIytUAH0lEvnJsN2Gzu4IVGKYqvRCwEQw02lbOAxCKDHJ7tKVn-C2OQSZs6Cr2kCOHM5TLQxIfWzrJzTU4dxHicqi-tDUeMF4xo2WKxIH8aNtVmx5TIMa_9tWaOb3npDRkUvk7MGCuk6CFsQSTNigqvmC2gSRtBnrVtdnCDPZF_yNBBLCpyOVFAEeR18pnqSdCw0HS4nkDHeA9yZ1RF4t6eYRM4xK8DrWAWMdDmT2_qdw4g4wi8WSjXeCCUmE3kOaUcmd3pXgS46qtb8lZdmw4PTriZUt9szpaRe5AgxO7F_Up0b8mbNZWDJ8kF7G9gHR6vToa28kWq_TwF7nOEtbdzPHZeL-vAk73D2cCJTmJB4einKCrsGLe4R-MOzKcWj7wu2Fjt1IvS6aqVRVy1uvwIRiWrUyhHDxSxsFGezirHsGKbBQej5Cytn40BaOlmFHMvK3S9vfpL8m2XjqmBW_q6sw4HnNYMBKTb9DzKrQb2aj3apEyuANeh-VSImgX09kiIWM70Q-QpUAI794N8bnJ3iykg6kjWro-3EZb44wOjuvMNGUp-HOzAbwY2P_Gk2y-Opo3UkXwoJV66n0LyMVgt8tLI0Dm7Q_VEsyw9u-b_mn3-uxblZR4fllyfCur7Ew8R2jiIZcry-o5iE2M1lIL7zlDy3-yCMhN75VXMI6EVvtuIwaHmLMubvu2N62nGVJ2n039LVAkByCoGzRBNIOXVgBy-aY49xnScn0F_fzyjQIK7VdsObJnO9Tg5bjurfRQ_ZQNx-g4CPG-vne2FkuI1ft7bLEcIMmVk0zn67VHwmcQ4rcBFHW_d_Vw2OfozrS0QjpcOIX1oC7h_MZIHb2SMzLKu3VyhqroZkYSz-Eebl3ieQuF8NO9d4kRJin_OBxki1_7IFSa239hcMoMk3x3hmHmsfeOb3msEtCdcXvcQOdt-inyjuNlvd_uhqs-p2LWcW32GBpAfYpnP26wMT-DHo1lX7k4R4BEsxWIelC-_-GnGpx7R7wsxvHmamyjjKeNfv_RvUeBYGqC67q64HAZY3ggQyyuzCanzXukzevdO5CkmxB61cNAE4OfKYcuPpITWE9XStCjA_hQ7xsnSz6tG0v69Ul9-qWxrmlChf9aubNe6HyBifPh7l7udwRnaVumiTMN0HaM232VdT-t5HoetaqSCnKIj46ql4iEZOqCuA/download"
    
    response = requests.head(url, allow_redirects=True)
    file_size = int(response.headers.get('content-length', 0))
    
    response = requests.get(url, stream=True)
    
    if response.status_code == 200:
        filename = "atlasv2.tar.gz"
    
        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=file_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=1024):
                file.write(chunk)
                bar.update(len(chunk))
    
        print(f"File downloaded successfully and saved as {filename}")
    else:
        print(f"Failed to download file. Status code: {response.status_code}")

    !tar -xzvf atlasv2.tar.gz -C atlas_dataset/

In [17]:
import pandas as pd
import json

def load_data(file_path):
    """Parse an EDR JSON-lines log into an actor/action/object event table.

    Only the four event types listed in ``event_specs`` are kept; every other
    record is ignored, as are blank lines.

    Args:
        file_path: Path to a ``.jsonl`` file with one JSON object per line.

    Returns:
        pandas.DataFrame with one row per kept event and the columns
        ``action``, ``actorID``, ``objectID``, ``object``, ``actorname``,
        ``objectname``, ``timestamp`` and ``actor_type``. The columns are
        present even when no event matched.
    """
    # event type -> (edge label, field holding the object id,
    #                field holding the object name, object node type)
    event_specs = {
        'endpoint.event.crossproc': ('crossproc', 'crossproc_guid', 'crossproc_name', 'PROCESS'),
        'endpoint.event.procstart': ('procstart', 'childproc_guid', 'childproc_name', 'PROCESS'),
        'endpoint.event.filemod': ('filemod', 'filemod_name', 'filemod_name', 'FILE'),
        'endpoint.event.netconn': ('netconn', 'remote_ip', 'remote_ip', 'SOCKET'),
    }
    columns = ['action', 'actorID', 'objectID', 'object',
               'actorname', 'objectname', 'timestamp', 'actor_type']

    log_data = []
    with open(file_path, 'r') as file:
        for line in file:
            if not line.strip():
                # Tolerate blank/trailing lines instead of failing in json.loads.
                continue
            doc = json.loads(line)

            spec = event_specs.get(doc.get('type'))
            if spec is None:
                continue
            edge_label, object_id_field, object_name_field, object_type = spec

            log_data.append({
                'action': edge_label,
                'actorID': doc.get('process_guid'),
                'objectID': doc.get(object_id_field),
                'object': object_type,
                'actorname': doc.get('process_path'),
                'objectname': doc.get(object_name_field),
                'timestamp': doc.get('device_timestamp'),
                # The acting side of every kept event is the process identified
                # by process_guid; later cells filter and label on this column.
                'actor_type': 'PROCESS',
            })

    return pd.DataFrame(log_data, columns=columns)

In [18]:
# Benign EDR event logs for hosts h1 and h2, relative to the directory the
# ATLASv2 archive is extracted into.
host1, host2 = (
    "atlas_dataset/atlasv2/data/benign/h1/cbc-edr/edr-h1-benign.jsonl",
    "atlas_dataset/atlasv2/data/benign/h2/cbc-edr/edr-h2-benign.jsonl",
)

In [19]:
# Parse each host's benign log into an event DataFrame (one row per kept event).
dfh1 = load_data(host1)
dfh2 = load_data(host2)

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch
from torch_geometric.data import Data
import os
import torch.nn.functional as F
import json
import warnings
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
warnings.filterwarnings('ignore')
from torch_geometric.loader import NeighborLoader
import multiprocessing

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from pprint import pprint
import gzip
from sklearn.manifold import TSNE
import json
import copy
import os

In [None]:
# Experiment configuration.
num_of_ctg = 10       # GNNs per client: (num_of_ctg - 1) process categories + 1 catch-all model
learning_rounds = 5   # number of federated rounds
epochs = 10           # local training epochs per model and round

# `hostdfs` (host name -> event DataFrame) is read below and by the training
# cells, but no visible cell defines it. Build it from the frames loaded above
# unless an earlier cell already provided one.
if "hostdfs" not in globals():
    hostdfs = {"h1": dfh1, "h2": dfh2}
hosts = list(hostdfs.keys())

In [None]:
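'''
This is the main featurizer. It constructs the graph from one host's event dataframe.

Args:
    df (DataFrame): Event table with the columns actorID, objectID, actor_type, object,
        actorname, action and objectname (one row per system event).

return:
    features (list): For each node, the list of tokens (actor name, action, object name)
        collected from the events the node takes part in. These are raw tokens, not vectors.
    feat_labels (list): Contains the label (index of the node type in `tokens`) for each node
    edge_index (list): Two parallel lists holding the source and destination node index of each edge.
    mapp (list): Contains the id of each node; the position in the list is the node index
    all_procs (list): Unique actor ids
    idx_to_proc (dict): Maps a position in all_procs to the actor id at that position
'''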

# Node types in label order: a type's position in this list is its class label.
tokens = ['PROCESS',
          'FILE',
          'SOCKET']

def prepare_graph(df):
    """Turn an event table into node token lists, node labels and an edge list.

    Every distinct ``actorID`` / ``objectID`` becomes a node (actors first, in
    order of first appearance). Each event appends
    ``[actorname, action, objectname]`` to its actor node and
    ``[actorname, action]`` to its object node, and contributes one
    actor -> object edge.

    ``df`` is only read: the label columns are computed locally instead of
    being written into the caller's frame, which is usually a filtered slice.

    Args:
        df: DataFrame with the columns ``actorID``, ``objectID``,
            ``actor_type``, ``object``, ``actorname``, ``action`` and
            ``objectname``.

    Returns:
        tuple ``(features, feat_labels, edge_index, mapp, all_procs, idx_to_proc)``:
            features: per-node list of tokens, ordered like ``mapp``.
            feat_labels: per-node index into ``tokens``; when an id occurs
                both as actor and as object the object label wins.
            edge_index: ``[sources, destinations]`` as node indices.
            mapp: node id for every node index.
            all_procs: unique actor ids.
            idx_to_proc: position in ``all_procs`` -> actor id (int keys).
    """
    dummies = {token: index for index, token in enumerate(tokens)}

    actor_ids = df['actorID']
    object_ids = df['objectID']
    actor_labels = df['actor_type'].map(dummies)
    object_labels = df['object'].map(dummies)

    # Actors are inserted first, so mapp[:len(all_procs)] == all_procs.
    nodes = {}
    for uid in actor_ids.unique():
        nodes[uid] = []
    for uid in object_ids.unique():
        nodes[uid] = []

    # Later entries overwrite earlier ones, so object labels take precedence.
    labels = {}
    labels.update(dict(zip(actor_ids, actor_labels.tolist())))
    labels.update(dict(zip(object_ids, object_labels.tolist())))

    # Column-wise zip yields the same values as DataFrame.iterrows() without
    # building a Series per row.
    for actor_id, object_id, actor_name, action, object_name in zip(
        actor_ids, object_ids, df['actorname'], df['action'], df['objectname']
    ):
        nodes[actor_id].extend([actor_name, action, object_name])
        nodes[object_id].extend([actor_name, action])

    mapp = list(nodes.keys())
    features = [nodes[node_id] for node_id in mapp]
    feat_labels = [labels[node_id] for node_id in mapp]

    index_map = {node_id: index for index, node_id in enumerate(mapp)}
    edge_index = [
        [index_map[src] for src in actor_ids],
        [index_map[dst] for dst in object_ids],
    ]

    all_procs = list(actor_ids.unique())
    idx_to_proc = dict(enumerate(all_procs))

    return features, feat_labels, edge_index, mapp, all_procs, idx_to_proc

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv, GATConv
import torch.nn.functional as F
import torch.nn as nn

class GCN(torch.nn.Module):
    """Two-layer graph network for node-type classification.

    Despite the class name, both layers are ``SAGEConv`` layers
    (``in_channel -> 32 -> out_channel``).
    """

    def __init__(self,in_channel,out_channel):
        super().__init__()
        hidden_channel = 32
        self.conv1 = SAGEConv(in_channel, hidden_channel, normalize=True)
        self.conv2 = SAGEConv(hidden_channel, out_channel, normalize=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return the row-wise softmax of ``encode(x, edge_index)``."""
        # NOTE(review): the training cells feed this softmax output to
        # CrossEntropyLoss, which applies log-softmax itself -- confirm intended.
        scores = self.encode(x, edge_index)
        return F.softmax(scores, dim=1)

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply conv1, tanh, then conv2 and return the un-normalised output."""
        hidden = F.tanh(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

    def freeze_conv_layers(self):
        """Disable gradient tracking for every parameter of both conv layers."""
        for layer in (self.conv1, self.conv2):
            for param in layer.parameters():
                param.requires_grad = False

In [None]:
def combine_word2vec_models(models):
    """Merge several Word2Vec models into one ``word -> vector`` dictionary.

    A word that occurs in several models gets the arithmetic mean of its
    vectors, with every model weighted equally. Vectors are summed and divided
    by the number of contributing models; a running pairwise average would
    give later models more weight as soon as more than two are combined.

    Args:
        models: Sequence of objects exposing ``.wv.index_to_key`` and
            ``.wv[word]`` (e.g. gensim ``Word2Vec`` models). May be empty.

    Returns:
        dict mapping each word to its (averaged) vector. NumPy vectors are
        converted to plain lists so the result can be serialised as JSON.
    """
    sums = {}
    counts = {}

    for model in models:
        for word in model.wv.index_to_key:
            vector = model.wv[word]
            if word in sums:
                # Out-of-place add: never modify a model's own vector storage.
                sums[word] = sums[word] + vector
                counts[word] += 1
            else:
                sums[word] = vector
                counts[word] = 1

    unified_dict = {}
    for word, total in sums.items():
        mean = total / counts[word] if counts[word] > 1 else total
        unified_dict[word] = mean.tolist() if isinstance(mean, np.ndarray) else mean

    return unified_dict

In [None]:
from gensim.models.callbacks import CallbackAny2Vec
import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

In [None]:
class EpochSaver(CallbackAny2Vec):
    """Training callback that writes the model to disk at the end of every epoch."""

    def __init__(self,filename):
        # The same path is reused for every epoch, so each save overwrites the
        # previous one.
        self.filename = filename
        self.epoch = 0

    def on_epoch_end(self, model):
        """Save ``model`` to ``self.filename`` and advance the epoch counter."""
        model.save(self.filename)
        self.epoch = self.epoch + 1

In [None]:
class EpochLogger(CallbackAny2Vec):
    """Training callback that prints a line at the start and end of every epoch."""

    def __init__(self):
        # Zero-based index of the epoch currently being reported.
        self.epoch = 0

    def on_epoch_begin(self, model):
        """Announce the start of the current epoch."""
        print(f"Epoch #{self.epoch} start")

    def on_epoch_end(self, model):
        """Announce the end of the current epoch and move on to the next index."""
        print(f"Epoch #{self.epoch} end")
        self.epoch = self.epoch + 1

In [None]:
from sklearn.utils import class_weight
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

# Stand-alone model/optimizer pair: 30 input features (the Word2Vec vector size
# used in this notebook) and one output per node type.
model = GCN(in_channel=30, out_channel=len(tokens)).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)

In [None]:
def train_word2vec_models():
    """Train one Word2Vec model per entry of ``hosts``.

    For every host the events are restricted to process actors and to object
    types listed in ``tokens``; the per-node token lists from ``prepare_graph``
    are used as training sentences. The model is saved after each epoch to
    ``Content_FL_Exp/<host>.model`` by the ``EpochSaver`` callback. Nothing is
    returned.
    """
    for host_name in hosts:
        print("Running host:",host_name)
        logger = EpochLogger()
        saver = EpochSaver(f"Content_FL_Exp/{host_name}.model")

        events = hostdfs[host_name]
        keep = (events['actor_type'] == 'PROCESS') & events['object'].isin(tokens)

        # Only the token lists (first element) are needed for Word2Vec.
        phrases = prepare_graph(events[keep])[0]
        Word2Vec(
            sentences=phrases,
            vector_size=30,
            window=5,
            min_count=1,
            workers=5,
            epochs=100,
            callbacks=[saver, logger],
        )

In [None]:
# Train one Word2Vec model per host. EpochSaver rewrites the model file after
# every epoch, so the file on disk always holds the latest state.
# The callbacks save into this directory; create it so a fresh checkout works.
os.makedirs("Content_FL_Exp", exist_ok=True)

for host_no, host_path in enumerate((host1, host2), start=1):
    logger = EpochLogger()
    saver = EpochSaver(f"Content_FL_Exp/Atlasv2_trustwatch_host{host_no}.model")
    df = load_data(host_path)
    df = df[df['actor_type'] == 'PROCESS']
    phrases,labels,edge_index,mapp,all_procs,idx_to_proc = prepare_graph(df)
    word2vec = Word2Vec(sentences=phrases, vector_size=30, window=5, min_count=1, workers=5, epochs=100, callbacks=[saver,logger])

In [None]:
# Reload the per-host Word2Vec models from disk and merge them into a single
# word -> vector dictionary that is stored as JSON.
word_models = []
for i in (1, 2):
    word2vec = Word2Vec.load(f"Content_FL_Exp/Atlasv2_trustwatch_host{i}.model")
    word_models += [word2vec]

global_word = combine_word2vec_models(word_models)

with open('Content_FL_Exp/Atlasv2_trustwatch_global.json', 'w') as json_file:
    json_file.write(json.dumps(global_word))

In [None]:
def load_word_model():
    """Load the merged embedding dictionary from its JSON file.

    Returns:
        dict mapping each word to a NumPy array built from the stored list.
    """
    with open('Content_FL_Exp/Atlasv2_trustwatch_global.json', 'r') as json_file:
        stored = json.load(json_file)

    embeddings = {}
    for word, vector in stored.items():
        embeddings[word] = np.array(vector)
    return embeddings

In [None]:
from sklearn.utils import class_weight
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

In [None]:
from collections import Counter
# From here on `word2vec` is the merged word -> vector dictionary read by infer().
word2vec = load_word_model()

def infer(doc):
    """Embed a list of tokens using the global ``word2vec`` dictionary.

    Every known token contributes its vector multiplied by how often it occurs
    in ``doc``; the sum is divided by the number of distinct known tokens.

    Args:
        doc: Iterable of hashable tokens.

    Returns:
        NumPy array of length 30. When ``doc`` is empty or contains no known
        token the zero vector is returned; dividing by a zero count would
        yield NaNs that end up in the node features.
    """
    global word2vec
    emb = np.zeros(30)
    count = 0
    # NOTE(review): the sum is frequency-weighted but normalised by the number
    # of distinct known tokens -- confirm this is the intended average.
    for k, v in Counter(doc).items():
        if k in word2vec:
            emb = emb + word2vec[k]*v
            count = count + 1
    if count == 0:
        return emb
    return emb / count

In [None]:
def init_gnns():
    """Create ``num_of_ctg`` independent, freshly initialised GNNs on ``device``.

    Returns:
        list of ``GCN`` models with 30 input features and one output per entry
        of ``tokens``.
    """
    return [GCN(30,len(tokens)).to(device) for _ in range(num_of_ctg)]

In [None]:
def define_categories(pids):
    """Split the distinct process ids into ``num_of_ctg - 1`` near-equal groups.

    The ids are sorted before being sliced so the same input always produces
    the same partition. Slicing in set iteration order makes the groups change
    between kernel restarts, because string hashing is randomised per process.

    Args:
        pids: Iterable of hashable process ids; duplicates are ignored.

    Returns:
        list of ``num_of_ctg - 1`` disjoint sets whose sizes differ by at most
        one (the first ``len % n`` groups get the extra element).
    """
    n = num_of_ctg - 1
    # key=str keeps the sort well defined even if ids of mixed types (or None) occur.
    ctg = sorted(set(pids), key=str)
    k, m = divmod(len(ctg), n)
    return [set(ctg[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]) for i in range(n)]

In [None]:
def map_pids_to_category_indices(pids, categories):
    """Look up, for every pid, the index of the first category containing it.

    Args:
        pids: Iterable of process ids.
        categories: Sequence of containers (sets) of process ids.

    Returns:
        dict ``pid -> category index``. Pids that belong to no category are
        left out.
    """
    pid_to_category_index = {}

    for pid in pids:
        first_hit = next(
            (position for position, members in enumerate(categories) if pid in members),
            None,
        )
        if first_hit is not None:
            pid_to_category_index[pid] = first_hit

    return pid_to_category_index

In [None]:
from torch.nn import CrossEntropyLoss
from sklearn.utils import class_weight
import copy

# Freshly initialised models; train_gnn_func and server_aggregate deep-copy these as their starting point.
templates = init_gnns()

def train_gnn_func(nodes,labels,edges,mapp,pids,idx_to_pid):
    """Train one client's per-category GNNs and its catch-all GNN.

    Process nodes are bucketed by the global ``categories``. For every category
    that has nodes on this client, a copy of the matching template is trained
    on those nodes plus the nodes reachable in two outgoing hops. When
    ``server_aggregate`` has already written a checkpoint for that category the
    copy starts from it, so each federated round continues from the aggregated
    model. The last model is the catch-all and is trained full-batch on the
    whole graph.

    Args:
        nodes: Per-node feature vectors, ordered like ``mapp``.
        labels: Per-node integer class labels.
        edges: ``[sources, destinations]`` as node indices.
        mapp: Node id for every node index.
        pids: Process ids present on this client.
        idx_to_pid: Unused; kept so existing callers keep working. The process
            id of node ``i`` is read from ``mapp[i]``, which works whether this
            dict has int keys (fresh from ``prepare_graph``) or str keys (after
            a JSON round trip).

    Returns:
        list with one entry per template: a trained model, or ``None`` for
        categories without nodes on this client. The last entry is the
        catch-all model.
    """
    pid_to_gnn_index = map_pids_to_category_indices(pids, categories)

    set_pids = set(pids)

    # Bucket the indices of process nodes by the category of their process id.
    train_splits = [[] for _ in range(len(categories))]
    for i, node_id in enumerate(mapp):
        if node_id in set_pids:
            train_splits[pid_to_gnn_index[node_id]].append(i)

    local_models = [copy.deepcopy(x) for x in templates]

    # The graph is identical for every category, so build it once.
    graph = Data(
        x=torch.tensor(np.asarray(nodes), dtype=torch.float).to(device),
        y=torch.tensor(labels, dtype=torch.long).to(device),
        edge_index=torch.tensor(edges, dtype=torch.long).to(device),
    )

    def get_neighbors(edge_index, sources):
        """Return the distinct destination nodes of edges leaving ``sources``."""
        neighbors = []
        for node in sources:
            outgoing = edge_index[0] == node
            neighbors.extend(edge_index[1, outgoing].tolist())
        return torch.tensor(list(set(neighbors)), dtype=torch.long)

    for i in range(len(local_models)-1):

        if len(train_splits[i]) == 0:
            local_models[i] = None
            continue

        # Must match the file name written by server_aggregate.
        checkpoint = f"Content_FL_Exp/target_atlas_global{i}.pth"
        if os.path.exists(checkpoint):
            local_models[i].load_state_dict(torch.load(checkpoint, map_location=device))

        optimizer = torch.optim.Adam(local_models[i].parameters(), lr=0.01, weight_decay=5e-4)
        criterion = CrossEntropyLoss()

        mask = torch.tensor([False]*graph.num_nodes, dtype=torch.bool)
        mask[train_splits[i]] = True

        # Extend the seed set with the nodes two outgoing hops away.
        one_hop_neighbors = get_neighbors(graph.edge_index, train_splits[i])
        two_hop_neighbors = get_neighbors(graph.edge_index, one_hop_neighbors)
        two_hop_neighbors = two_hop_neighbors[~mask[two_hop_neighbors]]
        mask[two_hop_neighbors] = True

        for epoch in range(epochs):
            print(f'Training GNN Category {i} Model for Epoch {epoch}')

            loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
            total_loss = 0
            for subg in loader:
                local_models[i].train()
                optimizer.zero_grad()
                out = local_models[i](subg.x, subg.edge_index)
                loss = criterion(out, subg.y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * subg.batch_size
            print("Loss: ", total_loss / mask.sum().item(), '\n')

    # Catch-all model: full-batch training on the entire graph.
    # NOTE(review): unlike the category models it never reloads an aggregated
    # checkpoint -- confirm whether that is intended.
    optimizer = torch.optim.Adam(local_models[-1].parameters(), lr=0.01, weight_decay=5e-4)
    criterion = CrossEntropyLoss()

    for epoch in range(epochs):
        print(f'Training Catch all GNN Category Model for Epoch {epoch}')
        local_models[-1].train()
        optimizer.zero_grad()
        out = local_models[-1](graph.x, graph.edge_index)
        loss = criterion(out, graph.y)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Loss: {loss.item()}")

    return local_models

In [None]:
# Shared state filled by load_clients_data():
procs_total = []    # process ids collected over all hosts (duplicates possible)
data_cache = {}     # host name -> [docs, labels, edges, mapp, pids, idx_to_pid]
categories = None   # partition of the process ids, set once all hosts are loaded

def load_clients_data():
    """Build every host's graph inputs and derive the global process categories.

    For each host in ``hosts`` the events are restricted to process actors and
    to object types listed in ``tokens``. The ``prepare_graph`` result is stored
    in ``data_cache`` and the host's process ids are appended to
    ``procs_total``. Finally ``categories`` is computed from all collected ids.
    """
    global categories, procs_total

    for name in hosts:
        events = hostdfs[name]
        keep = (events['actor_type'] == 'PROCESS') & events['object'].isin(tokens)

        docs, labels, edges, mapp, pids, idx_to_pid = prepare_graph(events[keep])
        data_cache[name] = [docs, labels, edges, mapp, pids, idx_to_pid]
        procs_total = procs_total + pids

    categories = define_categories(procs_total)

In [None]:
def client_handling_loop(client_id):
    """Embed one client's node documents and train its local GNNs.

    Args:
        client_id: Key into ``data_cache``.

    Returns:
        Whatever ``train_gnn_func`` returns for this client.
    """
    print(f"Running Setup on Client {client_id} \n")

    docs, labels, edges, mapp, pids, idx_to_pid = data_cache[client_id]

    # One embedding per node, in node order.
    nodes_feat = [infer(doc) for doc in docs]

    return train_gnn_func(nodes_feat, labels, edges, mapp, pids, idx_to_pid)

In [None]:
def server_aggregate(all_models):
    """Average the clients' models slot by slot and checkpoint the result.

    Args:
        all_models: One list per model slot, holding that slot's model from
            every client (``None`` where a client did not train it).

    Returns:
        Deep copy of ``templates`` in which every slot with at least one
        trained client model carries the element-wise mean of those models'
        parameters. Such slots are also saved to
        ``Content_FL_Exp/target_atlas_global<slot>.pth``; other slots keep the
        template weights and are not saved.
    """
    global_models = copy.deepcopy(templates)

    for slot, candidates in enumerate(all_models):
        trained = [candidate for candidate in candidates if candidate is not None]
        if not trained:
            continue

        client_states = [client.state_dict() for client in trained]
        averaged = global_models[slot].state_dict()

        for key in averaged.keys():
            stacked = torch.stack([state[key] for state in client_states], 0)
            averaged[key] = stacked.mean(0)

        global_models[slot].load_state_dict(averaged)
        torch.save(global_models[slot].state_dict(), f"Content_FL_Exp/target_atlas_global{slot}.pth")

    return global_models

In [None]:
import random
def perform_federated_learning(n_clients):
    """Run local training on every client, one after the other.

    Args:
        n_clients: Iterable of client ids (not a count, despite the name).

    Returns:
        list with each client's ``client_handling_loop`` result, in input order.
    """
    return [client_handling_loop(client) for client in n_clients]

In [None]:
!rm Content_FL_Exp/target_atlas_global*.pth

In [None]:
# NOTE(review): this cell reads a cache file named after a different experiment
# ("e5_cadets"). The `load_clients_data()` call in the next cell then overwrites
# the `data_cache` entries for every host in `hosts` and recomputes `categories`.
# It looks like a leftover -- confirm whether it is still needed.
with open('Content_FL_Exp/e5_cadets_ensemble_ben.json', 'r') as f:
    data_cache = json.load(f)

# Collect the process ids (second-to-last cache entry) of all hosts.
proc_total = []
for x in hosts:
    proc_total = proc_total + data_cache[x][-2]
    
categories = define_categories(proc_total)

In [None]:
# Fill data_cache for every host and derive the global process categories.
load_clients_data()

In [None]:
# Federated training: every round trains all clients locally, regroups the
# models per slot (client-major -> slot-major) and averages them on the server.
for r in range(learning_rounds):
    print(f"Federated Learning Round Number: {r}\n")
    client_models = perform_federated_learning(hosts)
    arranged_models = list(map(list, zip(*client_models)))
    global_models = server_aggregate(arranged_models)

In [None]:
'''
This function is used for constructing neighborhood around a given 
set of nodes for backwards or forward tracking
'''
from itertools import compress
from torch_geometric import utils

def construct_neighborhood(ids,mapp,edges,hops):
    """Collect the node ids reachable from ``ids`` within ``hops`` steps.

    Edges are followed in both directions. The result is the union of the
    neighbours found at every hop; a seed id is only part of the result if it
    is reached again from another node.

    Args:
        ids: Iterable of hashable seed node ids.
        mapp: Node id for every node index.
        edges: ``[sources, destinations]`` as node indices.
        hops: Number of expansion steps; values <= 0 give an empty set.

    Returns:
        set of node ids.
    """
    # A set makes the per-edge membership test O(1) even when a list is passed.
    frontier = set(ids)
    reached = set()

    for _ in range(hops):
        if not frontier:
            # Nothing left to expand; further hops cannot add nodes.
            break
        next_frontier = set()
        for src_index, dst_index in zip(edges[0], edges[1]):
            src_id, dst_id = mapp[src_index], mapp[dst_index]
            if src_id in frontier:
                next_frontier.add(dst_id)
            if dst_id in frontier:
                next_frontier.add(src_id)
        reached |= next_frontier
        frontier = next_frontier

    return reached

In [None]:
'''
This function logs the evaluation metrics.
'''

def helper(MP,all_pids,GP,edges,mapp):
    """Print and return neighbourhood-adjusted detection counts.

    Starting from the plain confusion sets, the counts are relaxed using
    two-hop neighbourhoods: a false positive within two hops of a ground-truth
    node is not counted, and a missed ground-truth node within two hops of a
    true positive is counted as detected.

    Args:
        MP: Iterable of node ids flagged by the model.
        all_pids: Iterable of all node ids.
        GP: Iterable of ground-truth positive node ids. Lists are accepted and
            converted to a set.
        edges: ``[sources, destinations]`` as node indices.
        mapp: Node id for every node index.

    Returns:
        tuple ``(TP, FP, FN, TN)`` of integer counts after the adjustment.
        Precision, recall and F-score are printed; a ratio whose denominator
        is zero is reported as 0.0 instead of raising.
    """
    MP, GP, all_pids = set(MP), set(GP), set(all_pids)

    TP = MP & GP
    FP = MP - GP
    FN = GP - MP
    TN = all_pids - (GP | MP)

    two_hop_gp = construct_neighborhood(GP,mapp,edges,2)
    two_hop_tp = construct_neighborhood(TP,mapp,edges,2)
    FPL = FP - two_hop_gp                          # drop alerts next to real attack nodes
    TPL = TP.union(FN.intersection(two_hop_tp))    # credit misses next to a detection
    FN = FN - two_hop_tp

    TP,FP,FN,TN = len(TPL),len(FPL),len(FN),len(TN)

    print(f"Number of True Positives: {TP}")
    print(f"Number of Fasle Positives: {FP}")
    print(f"Number of False Negatives: {FN}")
    print(f"Number of True Negatives: {TN}\n")

    prec = TP / (TP + FP) if (TP + FP) else 0.0
    print(f"Precision: {prec}")

    rec = TP / (TP + FN) if (TP + FN) else 0.0
    print(f"Recall: {rec}")

    fscore = (2*prec*rec) / (prec + rec) if (prec + rec) else 0.0
    print(f"Fscore: {fscore}\n")

    return TP, FP, FN, TN

In [None]:
def generate_groundtruth():
    # NOTE(review): the list literal below is opened with `[[` but never closed,
    # so this cell is a syntax error as written. The ground-truth node ids seem
    # to be missing -- restore them before running the evaluation.
    nodeids  = [[
    return nodeids

In [None]:
# Ground-truth malicious node ids; run_evaluation passes them to helper() as the positive set.
GT_mal = generate_groundtruth()

In [None]:
def load_data_test(path):
    """Load an evaluation log and turn it into graph inputs.

    Args:
        path: Path of the JSON-lines log, forwarded to ``load_data``.

    Returns:
        list ``[docs, labels, edges, mapp, pids, idx_to_pid]`` produced by
        ``prepare_graph`` for the events whose actor is a process and whose
        object type is listed in ``tokens``.
    """
    events = load_data(path)
    keep = (events['actor_type'] == 'PROCESS') & events['object'].isin(tokens)

    docs, labels, edges, mapp, pids, idx_to_pid = prepare_graph(events[keep])
    return [docs, labels, edges, mapp, pids, idx_to_pid]

In [None]:
# FIXME: load_data_test() requires a `path` argument, so this call raises
# TypeError as written. Pass the path of the attack log to evaluate.
data_mal = load_data_test()

In [None]:
def run_evaluation(thresh):
    """Flag anomalous nodes in ``data_mal`` with the aggregated GNNs and score them.

    Every node starts out flagged. For each aggregated checkpoint written by
    ``server_aggregate``, a node is cleared when that model predicts the node's
    own type with a normalised confidence of at least ``thresh``. Nodes still
    flagged after all models are the alerts handed to ``helper``.

    Args:
        thresh: Confidence threshold, compared against the rescaled margin
            between the two highest class scores.

    Returns:
        The ``(TP, FP, FN, TN)`` tuple returned by ``helper``.
    """
    global word2vec

    phrases,labels,edges,mapp,pids,idx_to_pid = data_mal

    model = GCN(30,len(tokens)).to(device)
    word2vec = load_word_model()  # infer() reads this global

    nodes = np.array([infer(x) for x in phrases])

    all_ids = set(mapp)

    graph = Data(
        x=torch.tensor(nodes, dtype=torch.float).to(device),
        y=torch.tensor(labels, dtype=torch.long).to(device),
        edge_index=torch.tensor(edges, dtype=torch.long).to(device),
    )

    # Kept on the CPU; device-side masks are moved over before indexing into it.
    flag = torch.ones(graph.num_nodes, dtype=torch.bool)

    loaded = 0
    for m_n in range(num_of_ctg):
        # Must match the file name written by server_aggregate.
        checkpoint = f"Content_FL_Exp/target_atlas_global{m_n}.pth"
        if not os.path.exists(checkpoint):
            # No client trained this slot; evaluating stale or random weights
            # would clear nodes arbitrarily.
            continue
        model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu')))
        loaded += 1

        model.eval()
        with torch.no_grad():
            out = model(graph.x, graph.edge_index)

        ranked, indices = out.sort(dim=1,descending=True)
        conf = (ranked[:,0] - ranked[:,1]) / ranked[:,0]
        # NOTE(review): this divides by the maximum rather than by (max - min),
        # so it is not a true min-max scaling. Left unchanged because `thresh`
        # is tuned against it -- confirm intended.
        conf = (conf - conf.min()) / conf.max()

        pred = indices[:,0]
        cond = (pred == graph.y) & (conf >= thresh)
        flag[cond.cpu()] = False

    if loaded == 0:
        print("Warning: no aggregated checkpoint found in Content_FL_Exp; every node is flagged.")

    index = utils.mask_to_index(flag).tolist()
    ids = set(mapp[x] for x in index)
    metrics = helper(ids,all_ids,GT_mal,edges,mapp)
    return metrics

In [None]:
# Evaluate with a confidence threshold of 0.95; the cell output is the (TP, FP, FN, TN) tuple.
run_evaluation(0.95)