# Active Directory Environment Classification based on Vulnerability

## Imports and globals

### Imports

In [4]:
import pandas as pd
import networkx as nx
import numpy as np
from py2neo import Graph
import random
import torch
from torch_geometric.data import Data, HeteroData
import json
import os
from torch_geometric.nn import to_hetero, GAT, SAGEConv
from torch_geometric.explain import Explanation
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer


### Global variables

In [5]:
# Database connection details
# This py2neo handle is what the cells further down pass as the
# `graph` / `session` argument of the helper functions.
# NOTE(review): endpoint and credentials are hard-coded -- confirm that is
# acceptable outside a local lab setup.
graph = Graph("http://localhost:7474", auth=("neo4j", "bloodhoundcommunityedition"))

# Object types and their corresponding properties
# Maps a node label to the node properties read for it. extract_features
# interpolates each name into Cypher as `n.<prop>` and exposes it as a
# Title-cased DataFrame column (`prop.title()`).
object_types_and_properties = {
    'Domain': ['name', 'objectid', 'highvalue'],
    'User': ['name', 'objectid', 'admincount', 'dontreqpreauth', 'pwdneverexpires', 'hasspn', 'highvalue', 'savedcredentials',
             'passwordnotreqd', 'pwdlastset', 'lastlogon', 'unconstraineddelegation', 'enabled', 'sensitive'],
    'Computer': ['name', 'objectid', 'operatingsystem', 'enabled', 'haslaps', 'highvalue', 'lastlogontimestamp', 
                 'pwdlastset', 'unconstraineddelegation', 'privesc', 'creddump', 
                 'exploitable'],
    'Group': ['name', 'objectid', 'highvalue', 'admincount'],
    'OU': ['name', 'objectid', 'highvalue', 'blocksInheritance'],
    'GPO': ['name', 'objectid', 'exploitable'],
    'Container': ['name', 'objectid', 'highvalue']
}

# Relationship types
# Passed to extract_relationships, which runs one
# `MATCH (source)-[r:<type>]->(target)` query per entry.
relationship_types = [
    'AddMember',
    'AddSelf',
    'AdminTo',
    'AllExtendedRights',
    'AllowedToAct',
    'AllowedToDelegate',
    'CanPSRemote',
    'CanRDP',
    'Contains',
    'ExecuteDCOM',
    'ForceChangePassword',
    'GenericAll',
    'GenericWrite',
    'GetChanges',
    'GetChangesAll',
    'GpLink',
    'HasSession',
    'MemberOf',
    'Owns',
    'ReadLAPSPassword',
    'SQLAdmin',
    'WriteDacl',
    'WriteOwner'
]

# OS possibilities
# Fixed category list (and column order) for the one-hot encoding of
# categorical columns in create_tensors_from_dataframe. Values that are not
# listed here get no column of their own there.
global_os_categories = ['Windows Server 2003 Enterprise Edition', 'Windows Server 2008 Datacenter', 'Windows Server 2008 Enterprise', 
                        'Windows Server 2008 R2 Datacenter', 'Windows Server 2008 R2 Enterprise', 'Windows Server 2008 R2 Standard', 
                        'Windows Server 2008 Standard', 'Windows Server 2012 Datacenter', 'Windows Server 2012 R2 Datacenter', 
                        'Windows Server 2012 R2 Standard', 'Windows Server 2012 Standard', 'Windows Server 2016 Datacenter', 
                        'Windows Server 2016 Standard']

# Object property types
# Per node label: Title-cased column name -> encoding used by
# create_tensors_from_dataframe:
#   "boolean"     cast to int (1/0)
#   "categorical" one-hot encoded over global_os_categories
#   "numerical"   coerced to numbers, NaN -> 0, then standard-scaled
#   "string"      ignored (no feature is produced)
# The keys follow the `prop.title()` column names built by extract_features.
object_property_types = {
    "Domain": {
        "Name": "string",
        "Objectid": "string",
        "Highvalue": "boolean"
    },
    "User": {
        "Name": "string",
        "Objectid": "string",
        "Admincount": "boolean",
        "Dontreqpreauth": "boolean",
        "Pwdneverexpires": "boolean",
        "Hasspn": "boolean",
        "Highvalue": "boolean",
        "Savedcredentials": "boolean",
        "Passwordnotreqd": "boolean",
        "Pwdlastset": "numerical",
        "Lastlogon": "numerical",
        "Unconstraineddelegation": "boolean",
        "Enabled": "boolean",
        "Sensitive": "boolean"
    },
    "Computer": {
        "Name": "string",
        "Objectid": "string",
        "Operatingsystem": "categorical",
        "Enabled": "boolean",
        "Haslaps": "boolean",
        "Highvalue": "boolean",
        "Lastlogontimestamp": "numerical",
        "Pwdlastset": "numerical",
        "Unconstraineddelegation": "boolean",
        "Privesc": "boolean",
        "Creddump": "boolean",
        "Exploitable": "boolean"
    },
    "Group": {
        "Name": "string",
        "Objectid": "string",
        "Highvalue": "boolean",
        "Admincount": "boolean"
    },
    "OU": {
        "Name": "string",
        "Objectid": "string",
        "Highvalue": "boolean",
        "Blocksinheritance": "boolean"
    },
    "GPO": {
        "Name": "string",
        "Objectid": "string",
        "Exploitable": "boolean"
    },
    "Container": {
        "Name": "string",
        "Objectid": "string",
        "Highvalue": "boolean"
    }
}

## Functions for handling graph database

In [6]:
def clear_neo4j_database(session):
    """Empty the Neo4j database and (re)create the import-id constraints.

    Steps, in order: delete all nodes and relationships in batches of 10k,
    reset the schema through APOC, drop every constraint and index that is
    still listed, then try to create a ``neo4jImportId`` uniqueness
    constraint for each label used by the import.

    Args:
        session: Object exposing ``run(cypher)`` whose result iterates as
            records indexable by column name. The notebook passes the
            module-level py2neo ``graph``.
    """
    # Delete nodes and edges with batching into 10k objects - From DBCreator
    while True:
        # Reset per batch so a result without rows ends the loop instead of
        # re-using the previous count forever.
        deleted = 0
        result = session.run(
            "MATCH (n) WITH n LIMIT 10000 DETACH DELETE n RETURN count(n)")
        for record in result:
            deleted = int(record['count(n)'])
        if deleted <= 0:
            break

    session.run("CALL apoc.schema.assert({},{},true);")

    # Remove constraints - From DBCreator. Names are collected first so the
    # listing is fully read before the schema is modified.
    constraint_names = [record['name'] for record in session.run("SHOW CONSTRAINTS")]
    for name in constraint_names:
        session.run("DROP CONSTRAINT {}".format(name))

    # Drop whatever indexes are left; nothing happens when none are listed.
    index_names = [record['name'] for record in session.run("SHOW INDEXES")]
    for name in index_names:
        session.run("DROP INDEX {}".format(name))

    # Setting constraints
    constraints = [
        "CREATE CONSTRAINT FOR (n:Base) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Domain) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Computer) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:User) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:OU) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:GPO) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Compromised) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Group) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Container) REQUIRE n.neo4jImportId IS UNIQUE;",
    ]

    for constraint in constraints:
        try:
            session.run(constraint)
        except Exception as exc:
            # Best effort: one failing constraint must not abort the reset,
            # but it should be visible. KeyboardInterrupt/SystemExit are no
            # longer swallowed.
            print(f"Could not create constraint ({exc}): {constraint}")

    # Kept from the original clean-up; the database is already empty here,
    # so these match nothing.
    session.run("match (a) -[r] -> () delete a, r")
    session.run("match (a) delete a")

In [7]:
def load_graph_from_json(session, file_path):
    """
    Loads a graph from a JSON file into Neo4j.

    The import runs server-side through ``apoc.import.json`` wrapped in
    ``apoc.periodic.iterate``. The path is handed over as a query parameter
    instead of being spliced into the Cypher text, so quotes or backslashes
    in the path cannot break the statement.

    Args:
        session: Object exposing ``run(cypher, parameters)``; the notebook
            passes the module-level py2neo ``graph``.
        file_path: Path of the JSON export to import.

    Raises:
        OSError: If ``file_path`` cannot be opened for reading locally
            (e.g. ``FileNotFoundError``), as before.
    """
    # Fail fast when the file is not readable from this process; the handle
    # itself is not needed, so it is closed again right away.
    with open(file_path, 'r'):
        pass

    query = (
        "PROFILE CALL apoc.periodic.iterate("
        "\"CALL apoc.import.json($file_path)\", "
        "\"RETURN 1\", "
        "{batchSize:1000, params:{file_path:$file_path}})"
    )
    session.run(query, {"file_path": file_path})

## Functions for extracting features and creating dataset

#### Extract features from neo4j database

In [8]:
def extract_features(graph, labels, properties):
    """Read the given properties of all nodes with a label into a DataFrame.

    Args:
        graph: Object exposing ``run(cypher)``; its result is handed to
            ``pd.DataFrame``.
        labels: Label expression placed after ``n:`` in the MATCH clause.
        properties: Property names to return for every matched node.

    Returns:
        DataFrame whose first column is ``'Node ID'`` (``id(n)``) followed by
        one column per property, named ``prop.title()``. When nothing
        matches, an empty frame with those columns is returned.
    """
    column_names = ['Node ID']
    column_names.extend(prop.title() for prop in properties)

    # One "n.<prop> AS node_<prop>" projection per requested property.
    projection = ", ".join(f"n.{prop} AS node_{prop}" for prop in properties)
    query = (
        "\n"
        f"    MATCH (n:{labels})\n"
        "    RETURN \n"
        "        id(n) AS node_id, \n"
        f"        {projection}\n"
        "    "
    )

    frame = pd.DataFrame(graph.run(query))
    if frame.empty:
        # No rows came back: start from a frame that has the expected width.
        frame = pd.DataFrame(columns=column_names)

    frame.columns = column_names
    return frame

In [9]:
def get_node_properties(graph, label):
    """Return the distinct property keys found on nodes carrying ``label``.

    Args:
        graph: Object exposing ``run(cypher)`` whose result iterates as
            records indexable by column name.
        label: Node label placed after ``n:`` in the MATCH clause.

    Returns:
        List of property key names, in the order the query yields them.
    """
    cypher = f"""
    MATCH (n:{label})
    WITH keys(n) AS keys
    UNWIND keys AS key
    RETURN DISTINCT key
    """
    property_names = []
    for row in graph.run(cypher):
        property_names.append(row["key"])
    return property_names

# Property keys actually present in the database for every configured label.
# The labels come from object_types_and_properties so both tables stay in
# sync; the previous hand-written list had no 'OU' entry although OU is
# configured everywhere else.
all_possible_object_types_and_properties = {
    label: get_node_properties(graph, label)
    for label in object_types_and_properties
}

In [10]:
def extract_relationships(graph, rel_types):
    """Collect all relationships of the given types into one DataFrame.

    Args:
        graph: Object exposing ``run(cypher)``; each result is handed to
            ``pd.DataFrame``.
        rel_types: Relationship type names; one query is issued per type.

    Returns:
        DataFrame with the columns ``'Source ID'``, ``'Target ID'`` and
        ``'Relationship Type'``. When no relationship of any requested type
        exists, an empty frame with these columns is returned (previously
        ``pd.concat`` raised ``ValueError`` on the empty list).
    """
    columns = ['Source ID', 'Target ID', 'Relationship Type']

    # One DataFrame per relationship type that actually occurs
    dfs = []

    for rel_type in rel_types:
        # Define the Cypher query with dynamic relationship type
        query = f"""
        MATCH (source)-[r:{rel_type}]->(target)
        RETURN
            id(source) AS source_id,
            id(target) AS target_id,
            TYPE(r) AS relationship_type
        """

        df = pd.DataFrame(graph.run(query))
        if df.empty:
            continue

        df.columns = columns
        dfs.append(df)

    if not dfs:
        # e.g. a freshly cleared database: nothing to concatenate
        return pd.DataFrame(columns=columns)

    return pd.concat(dfs, ignore_index=True)

#### Filter dataframes

In [11]:
def filter_dataframes(dfs):
    """Normalise the cell values of every DataFrame in ``dfs`` in place.

    For each cell: a list is replaced by its first element, and a missing
    value (the string ``"null"``, ``None`` or an empty list) becomes
    ``False``. ``None`` is handled the same way as in the inline cleaning
    done by the dataset-building cells of this notebook.

    Args:
        dfs: Mapping of object type name to DataFrame; the frames are
            modified in place.

    Returns:
        The same ``dfs`` mapping.
    """
    def _clean(value):
        # Take first element if it is a list (an empty list counts as missing)
        if isinstance(value, list):
            value = value[0] if value else None
        # Set to False if null
        if value is None or value == "null":
            return False
        return value

    for frame in dfs.values():
        for column in frame.columns:
            frame[column] = frame[column].apply(_clean)
    return dfs

#### Helper functions

In [12]:
import pandas as pd
import torch

def create_tensors_from_dataframe(df, object_type, object_property_types, missing_features):
    """Encode the typed columns of ``df`` as a float feature matrix.

    Args:
        df: DataFrame with a ``'Node ID'`` column plus property columns.
        object_type: Key into ``object_property_types`` / ``missing_features``.
        object_property_types: ``{object_type: {column: type}}`` where type is
            ``'boolean'``, ``'categorical'``, ``'numerical'`` or ``'string'``.
        missing_features: ``{object_type: [column, ...]}`` of columns to skip.

    Returns:
        ``(node_ids, x)``: ``node_ids`` holds the values of the ``'Node ID'``
        column and ``x`` is a ``torch.float`` tensor of shape
        ``(len(df), num_features)``. With no usable column the shape is
        ``(len(df), 0)`` (previously a 1-D empty tensor).

    Raises:
        ValueError: If a column is declared with an unknown property type.
    """
    node_ids = df['Node ID'].values
    column_types = object_property_types[object_type]
    skipped_columns = missing_features.get(object_type, [])

    feature_columns = []  # one 1-D array of length len(df) per feature
    for column in df.columns:
        if column not in column_types or column in skipped_columns:
            continue

        property_type = column_types[column]
        if property_type == 'boolean':
            # Convert to 1s and 0s
            feature_columns.append(df[column].astype(int).values)
        elif property_type == 'categorical':
            # One-hot encoding with a fixed set and order of categories:
            # absent categories become all-zero columns, unlisted values get
            # no column.
            one_hot_df = pd.get_dummies(df[column], dtype=int)
            one_hot_df = one_hot_df.reindex(columns=global_os_categories, fill_value=0)
            for category in global_os_categories:
                feature_columns.append(one_hot_df[category].values)
        elif property_type == 'numerical':
            # Non-numeric entries and NaN become 0; 2-D shape for the scaler
            numeric_values = pd.to_numeric(df[column], errors='coerce').fillna(0).values.reshape(-1, 1)
            if len(numeric_values) > 0:
                scaled_values = StandardScaler().fit_transform(numeric_values)
            else:
                # StandardScaler refuses to fit on zero samples
                scaled_values = numeric_values.astype(float)
            feature_columns.append(scaled_values.flatten())
        elif property_type == 'string':
            # Ignore string columns
            continue
        else:
            raise ValueError(f"Unknown property type '{property_type}' for column '{column}'")

    if feature_columns:
        # Stack once in NumPy instead of building a tensor from a list of arrays
        features = np.column_stack(feature_columns)
    else:
        features = np.empty((len(df), 0))
    return node_ids, torch.tensor(features, dtype=torch.float)

In [13]:
def get_object_type(node_id, source_id_maps):
    """Return the first object type whose id collection contains ``node_id``.

    Args:
        node_id: Node ID to look up.
        source_id_maps: Mapping of object type to a collection of node IDs.

    Returns:
        The matching object type, or ``None`` when no collection contains the
        id (no exception is raised in that case).
    """
    matching_types = (
        obj_type
        for obj_type, ids in source_id_maps.items()
        if node_id in ids
    )
    return next(matching_types, None)

In [14]:
def clear_neo4j_database(session):
    """Empty the Neo4j database and (re)create the import-id constraints.

    Steps, in order: delete all nodes and relationships in batches of 10k,
    reset the schema through APOC, drop every constraint and index that is
    still listed, then try to create a ``neo4jImportId`` uniqueness
    constraint for each label used by the import.

    Args:
        session: Object exposing ``run(cypher)`` whose result iterates as
            records indexable by column name. The notebook passes the
            module-level py2neo ``graph``.
    """
    # Delete nodes and edges with batching into 10k objects - From DBCreator
    while True:
        # Reset per batch so a result without rows ends the loop instead of
        # re-using the previous count forever.
        deleted = 0
        result = session.run(
            "MATCH (n) WITH n LIMIT 10000 DETACH DELETE n RETURN count(n)")
        for record in result:
            deleted = int(record['count(n)'])
        if deleted <= 0:
            break

    session.run("CALL apoc.schema.assert({},{},true);")

    # Remove constraints - From DBCreator. Names are collected first so the
    # listing is fully read before the schema is modified.
    constraint_names = [record['name'] for record in session.run("SHOW CONSTRAINTS")]
    for name in constraint_names:
        session.run("DROP CONSTRAINT {}".format(name))

    # Drop whatever indexes are left; nothing happens when none are listed.
    index_names = [record['name'] for record in session.run("SHOW INDEXES")]
    for name in index_names:
        session.run("DROP INDEX {}".format(name))

    # Setting constraints
    constraints = [
        "CREATE CONSTRAINT FOR (n:Base) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Domain) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Computer) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:User) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:OU) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:GPO) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Compromised) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Group) REQUIRE n.neo4jImportId IS UNIQUE;",
        "CREATE CONSTRAINT FOR (n:Container) REQUIRE n.neo4jImportId IS UNIQUE;",
    ]

    for constraint in constraints:
        try:
            session.run(constraint)
        except Exception as exc:
            # Best effort: one failing constraint must not abort the reset,
            # but it should be visible. KeyboardInterrupt/SystemExit are no
            # longer swallowed.
            print(f"Could not create constraint ({exc}): {constraint}")

    # Kept from the original clean-up; the database is already empty here,
    # so these match nothing.
    session.run("match (a) -[r] -> () delete a, r")
    session.run("match (a) delete a")

In [15]:
def load_graph_from_json(session, file_path):
    """
    Loads a graph from a JSON file into Neo4j.

    The import runs server-side through ``apoc.import.json`` wrapped in
    ``apoc.periodic.iterate``. The path is handed over as a query parameter
    instead of being spliced into the Cypher text, so quotes or backslashes
    in the path cannot break the statement.

    Args:
        session: Object exposing ``run(cypher, parameters)``; the notebook
            passes the module-level py2neo ``graph``.
        file_path: Path of the JSON export to import.

    Raises:
        OSError: If ``file_path`` cannot be opened for reading locally
            (e.g. ``FileNotFoundError``), as before.
    """
    # Fail fast when the file is not readable from this process; the handle
    # itself is not needed, so it is closed again right away.
    with open(file_path, 'r'):
        pass

    query = (
        "PROFILE CALL apoc.periodic.iterate("
        "\"CALL apoc.import.json($file_path)\", "
        "\"RETURN 1\", "
        "{batchSize:1000, params:{file_path:$file_path}})"
    )
    session.run(query, {"file_path": file_path})

#### Create PyG heterogeneous dataset from Pandas DataFrame

In [16]:
def create_heterogeneous_graph(object_dfs, relationship_df, object_property_types, missing_features):
    """Build a PyG ``HeteroData`` graph from node and relationship DataFrames.

    Args:
        object_dfs: Mapping of object type to its node DataFrame (must contain
            a ``'Node ID'`` column, see ``create_tensors_from_dataframe``).
        relationship_df: DataFrame whose rows unpack into
            ``(source_id, target_id, relationship_type)``.
        object_property_types: Passed through to
            ``create_tensors_from_dataframe``.
        missing_features: Passed through to ``create_tensors_from_dataframe``.

    Returns:
        ``HeteroData`` with ``x`` set for every object type and ``edge_index``
        set for every ``(source_type, relationship_type, target_type)`` that
        occurs. Rows referring to an unknown node ID are skipped and counted
        in the printed summary.
    """
    data = HeteroData()
    skipped_edges = 0
    skipped_nodes = 0

    # node ID -> (object type, row position inside that type's feature matrix).
    # Built once so each edge costs two dict lookups instead of a scan over
    # every ID array. If an ID occurs under several types the first type
    # wins, matching the scan order of get_object_type.
    node_lookup = {}

    # Add Nodes and Features
    for object_name, object_df in object_dfs.items():
        node_ids, data[object_name].x = create_tensors_from_dataframe(
            object_df, object_name, object_property_types, missing_features)
        for position, node_id in enumerate(node_ids.tolist()):
            node_lookup.setdefault(node_id, (object_name, position))

    # Add Edges
    edge_index_dict = {}  # edge type -> ([source positions], [target positions])
    for row in relationship_df.itertuples(index=False, name=None):
        try:
            source_id, target_id, rel_type = row
        except ValueError:
            # Row does not have exactly three fields
            skipped_edges += 1
            continue

        source = node_lookup.get(source_id)
        target = node_lookup.get(target_id)
        if source is None or target is None:
            # At least one endpoint is not among the extracted nodes
            skipped_nodes += 1
            continue

        source_type, source_index = source
        target_type, target_index = target
        sources, targets = edge_index_dict.setdefault(
            (source_type, rel_type, target_type), ([], []))
        sources.append(source_index)
        targets.append(target_index)

    # Create edge_index tensors from the accumulated indices
    for (source_type, rel_type, target_type), (source_indices, target_indices) in edge_index_dict.items():
        data[source_type, rel_type, target_type].edge_index = torch.tensor(
            [source_indices, target_indices], dtype=torch.long)

    print(f"{skipped_edges} edges skipped")
    print(f"{skipped_nodes} nodes skipped")
    return data

## Building dataset

### Test data

In [29]:
# Wipe the Neo4j database through the module-level `graph` handle.
# NOTE(review): presumably the real environment is imported into Neo4j before
# the next cell extracts features -- confirm the intended cell order.
clear_neo4j_database(graph)

In [None]:
# Build one PyG graph from whatever is currently loaded in Neo4j and label it
# as vulnerable (y = 1).
object_dfs = {}

for object_type, properties in object_types_and_properties.items():
    imputer = SimpleImputer(strategy='most_frequent')  # Or 'median'
    object_df = extract_features(graph, object_type, properties)
    object_dfs[object_type] = object_df

    print(f"Object dataframe for {object_type}: {object_dfs[object_type]}")

    # `column` rather than `property`, which would shadow the builtin
    for column in object_df.columns:
        # Take first element if it is a list
        object_df[column] = object_df[column].apply(lambda x: x[0] if isinstance(x, list) else x)
        # Set boolean to False if null
        object_df[column] = object_df[column].apply(lambda x: False if x == "null" or x is None else x)
    if not object_df.empty:
        object_df[:] = imputer.fit_transform(object_df)  # Impute in place

relationship_df = extract_relationships(graph, relationship_types)

# Identify missing features
# DataFrame columns are Title-cased by extract_features, so the configured
# property names are compared in the same form; comparing the raw lower-case
# names reported every property as missing.
missing_features = {
    object_type: [
        prop.title()
        for prop in object_types_and_properties[object_type]
        if prop.title() not in df.columns
    ]
    for object_type, df in object_dfs.items()
}

print(f"Missing features: {missing_features}")

print("Creating heterogenous PyG graph")
real_graph_data_rtm_vuln = create_heterogeneous_graph(object_dfs, relationship_df, object_property_types, missing_features)
real_graph_data_rtm_vuln.y = torch.tensor([1])


### Vulnerable environments

In [None]:
# Directory for ADSynth datasets
# Must point at the folder holding the exported *.json graphs before this
# cell is run; os.listdir fails on the empty placeholder.
data_dir = ""

vulnerable_dataset = []

#Main generation function
print("===== Start =====")
for filename in os.listdir(data_dir):
    if not filename.endswith(".json"):
        continue

    print("----- Starting process for new graph -----")
    print(f"Now processing: {filename} ")

    file_path = os.path.join(data_dir, filename)
    print("Clearing database")
    clear_neo4j_database(graph)

    print("Loading json file")
    load_graph_from_json(graph, file_path)

    object_dfs = {}
    for object_type, properties in object_types_and_properties.items():
        imputer = SimpleImputer(strategy='most_frequent')  # Or 'median'
        object_df = extract_features(graph, object_type, properties)
        object_dfs[object_type] = object_df

        # `column` rather than `property`, which would shadow the builtin
        for column in object_df.columns:
            # Take first element if it is a list
            object_df[column] = object_df[column].apply(lambda x: x[0] if isinstance(x, list) else x)
            # Set boolean to False if null
            object_df[column] = object_df[column].apply(lambda x: False if x == "null" or x is None else x)
        # An object type without nodes yields an empty frame, which the
        # imputer cannot be fitted on (same guard as in the test-data cell).
        if not object_df.empty:
            object_df[:] = imputer.fit_transform(object_df)  # Impute in place

    relationship_df = extract_relationships(graph, relationship_types)

    print("Creating heterogenous PyG graph")

    # DataFrame columns are Title-cased by extract_features, so the
    # configured property names are compared in the same form.
    missing_features = {
        object_type: [
            prop.title()
            for prop in object_types_and_properties[object_type]
            if prop.title() not in df.columns
        ]
        for object_type, df in object_dfs.items()
    }

    graph_data = create_heterogeneous_graph(object_dfs, relationship_df, object_property_types, missing_features)
    graph_data.y = torch.tensor([1])  # label 1 = vulnerable

    vulnerable_dataset.append(graph_data)

torch.save(vulnerable_dataset, "vulnerable_dataset3.pt")

### Safe environments

In [None]:
# Directory for ADSynth datasets
# Must point at the folder holding the exported *.json graphs before this
# cell is run; os.listdir fails on the empty placeholder.
data_dir = ""

safe_dataset = []

#Main generation function
print("===== Start =====")
for filename in os.listdir(data_dir):
    if not filename.endswith(".json"):
        continue

    print("----- Starting process for new graph -----")
    print(f"Now processing: {filename} ")

    file_path = os.path.join(data_dir, filename)
    print("Clearing database")
    clear_neo4j_database(graph)

    print("Loading json file")
    load_graph_from_json(graph, file_path)

    object_dfs = {}
    for object_type, properties in object_types_and_properties.items():
        imputer = SimpleImputer(strategy='most_frequent')  # Or 'median'
        object_df = extract_features(graph, object_type, properties)
        object_dfs[object_type] = object_df

        # `column` rather than `property`, which would shadow the builtin
        for column in object_df.columns:
            # Take first element if it is a list
            object_df[column] = object_df[column].apply(lambda x: x[0] if isinstance(x, list) else x)
            # Set boolean to False if null
            object_df[column] = object_df[column].apply(lambda x: False if x == "null" or x is None else x)
        # An object type without nodes yields an empty frame, which the
        # imputer cannot be fitted on (same guard as in the test-data cell).
        if not object_df.empty:
            object_df[:] = imputer.fit_transform(object_df)  # Impute in place

    relationship_df = extract_relationships(graph, relationship_types)

    print("Creating heterogenous PyG graph")

    # DataFrame columns are Title-cased by extract_features, so the
    # configured property names are compared in the same form.
    missing_features = {
        object_type: [
            prop.title()
            for prop in object_types_and_properties[object_type]
            if prop.title() not in df.columns
        ]
        for object_type, df in object_dfs.items()
    }

    graph_data = create_heterogeneous_graph(object_dfs, relationship_df, object_property_types, missing_features)
    graph_data.y = torch.tensor([0])  # label 0 = safe

    safe_dataset.append(graph_data)

torch.save(safe_dataset, "safe_dataset3.pt")

### Combined

In [34]:
import torch
from torch_geometric.transforms import BaseTransform

class CustomAddSelfLoops(BaseTransform):
    """Append one self-loop per node to every non-bipartite edge store.

    Args:
        attr: Name of the edge attribute to extend for the new self-loops,
            or ``None`` to leave edge attributes untouched.
        fill_value: Value written into ``attr`` for the self-loop edges.
    """

    def __init__(self, attr='edge_weight', fill_value=1.0):
        self.attr = attr
        self.fill_value = fill_value

    def forward(self, data):
        """Add the self-loops to ``data`` in place and return it."""
        for store in data.edge_stores:
            if store.is_bipartite() or 'edge_index' not in store:
                continue

            # Get the number of nodes for the current node type
            num_nodes = store.size(0)
            edge_index = store.edge_index

            # Create self-loop edges (connect each node to itself) with the
            # dtype and device of the existing edges so the concatenation
            # also works for graphs that were moved off the CPU.
            self_loop_edges = torch.arange(
                num_nodes, dtype=edge_index.dtype, device=edge_index.device
            ).repeat(2, 1)

            # Concatenate with existing edges
            store.edge_index = torch.cat([edge_index, self_loop_edges], dim=1)

            # Extend the edge attribute only when the store already has one.
            # Creating it for the self-loops alone would leave an attribute
            # with num_nodes entries next to an edge_index with more columns.
            existing_attr = store.get(self.attr) if self.attr is not None else None
            if existing_attr is not None:
                self_loop_attr = torch.full(
                    (num_nodes, *existing_attr.shape[1:]),
                    self.fill_value,
                    dtype=existing_attr.dtype,
                    device=existing_attr.device,
                )
                store[self.attr] = torch.cat([existing_attr, self_loop_attr])
        return data

## Loading Dataset

In [None]:
from torch_geometric.transforms import AddSelfLoops

# Reload the graph lists saved by the dataset-building cells above.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# files created locally. Newer PyTorch releases presumably need
# weights_only=False for these pickled graph objects; verify against the
# installed version.
vulnerable_dataset = torch.load("vulnerable_dataset3.pt")
safe_dataset = torch.load("safe_dataset3.pt")

# Vulnerable graphs come first, then the safe ones.
dataset = vulnerable_dataset + safe_dataset

# Example usage:
# transform = CustomAddSelfLoops()
# dataset = [transform(data) for data in dataset]

# Summary; every per-graph figure below describes only the first graph.
print(f'Number of graphs: {len(dataset)}')
print(f'Graph has structure: {dataset[0]}')
print(f'Number of features: {dataset[0].num_features}')
print(f'Number of classes: {len(torch.unique(torch.cat([data.y for data in dataset])))}') 
print(f'Has isolated nodes: {dataset[0].has_isolated_nodes()}')


In [38]:
import torch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

# Fix the torch RNG; the splits below are seeded separately via random_state.
torch.manual_seed(1)

# One label per graph, kept as the stored `y` tensors
labels = [data.y for data in dataset] 

# Stratified split into train (80%) and temp (20%, val + test)
train_dataset, temp_dataset, train_labels, temp_labels = train_test_split(
    dataset, labels, test_size=0.2, stratify=labels, random_state=1
)

# Stratified split of temp into equal val and test halves
val_dataset, test_dataset, val_labels, test_labels = train_test_split(
    temp_dataset, temp_labels, test_size=0.5, stratify=temp_labels, random_state=1
)

## Visualising data

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import defaultdict
from collections import Counter  # Import Counter for top-k analysis

def visualize_graph_properties(dataset, top_k_edges=10):
    """Plot summary statistics of a labelled graph dataset.

    Shows, in this order: node-count histograms per node type (one figure per
    class), edge-count histograms per edge type on a log x-axis (one figure
    per class), a bar chart of the ``top_k_edges`` edge types with the
    largest absolute difference in total edge count between the classes,
    violin plots of all feature values per node type (one figure per class),
    a combined box plot of both classes and a pie chart of the class
    distribution.

    Args:
        dataset: Iterable of graphs exposing ``y`` (read via ``y.item()``;
            1 is treated as vulnerable, 0 as safe), ``node_types``,
            ``edge_types``, ``x_dict`` and per-type ``num_nodes`` /
            ``edge_index``.
        top_k_edges: Number of edge types shown in the difference bar chart.
    """
    # --- Separate Data by Class ---
    vulnerable_data = [data for data in dataset if data.y.item() == 1]
    safe_data = [data for data in dataset if data.y.item() == 0]

    # --- Node Count Distribution (Separate Plots for Each Class) ---
    # node type -> {class label -> [node count of each graph]}
    node_counts_by_type = defaultdict(lambda: {0: [], 1: []})  
    for data in dataset:
        class_label = data.y.item()
        for node_type in data.node_types:
            node_counts_by_type[node_type][class_label].append(data[node_type].num_nodes)

    # enumerate() yields class label 0 for 'Safe' and 1 for 'Vulnerable'
    for class_label, class_name in enumerate(['Safe', 'Vulnerable']):
        plt.figure(figsize=(12, 6))
        for node_type, counts in node_counts_by_type.items():
            sns.histplot(counts[class_label], label=f'{node_type} ({class_name})', kde=True, bins=20)
        plt.title(f'Node Count Distribution by Node Type ({class_name})')
        plt.xlabel('Number of Nodes')
        plt.ylabel('Frequency')
        plt.legend()
        plt.show()

    # --- Edge Count Distribution (Separate Plots for Each Class + Log Scale) ---
    # edge type -> {class label -> [edge count of each graph]}
    edge_counts_by_type = defaultdict(lambda: {0: [], 1: []})
    for data in dataset:
        class_label = data.y.item()
        for edge_type in data.edge_types:
            edge_counts_by_type[edge_type][class_label].append(data[edge_type].edge_index.shape[1])

    for class_label, class_name in enumerate(['Safe', 'Vulnerable']):
        plt.figure(figsize=(12, 6))
        for edge_type, counts in edge_counts_by_type.items():
            sns.histplot(counts[class_label], label=f'{edge_type} ({class_name})', kde=True, bins=20)
        plt.title(f'Edge Count Distribution by Edge Type ({class_name}) - Log Scale')
        plt.xlabel('Number of Edges')
        plt.ylabel('Frequency')
        plt.xscale('log')
        plt.legend()
        plt.show()


    # --- Edge Count Distribution (Top-k, Difference between classes) ---
    # Total number of edges per edge type, summed over all graphs of a class
    edge_counts_safe = Counter()
    for data in safe_data:
        for edge_type in data.edge_types:
            edge_counts_safe[edge_type] += data[edge_type].edge_index.shape[1]

    edge_counts_vulnerable = Counter()
    for data in vulnerable_data:
        for edge_type in data.edge_types:
            edge_counts_vulnerable[edge_type] += data[edge_type].edge_index.shape[1]

    # Calculate the difference in counts (vulnerable minus safe)
    edge_count_diff = {
        edge_type: edge_counts_vulnerable.get(edge_type, 0) - edge_counts_safe.get(edge_type, 0)
        for edge_type in set(edge_counts_safe) | set(edge_counts_vulnerable)
    }

    # Sort by absolute difference
    sorted_edge_count_diff = sorted(edge_count_diff.items(), key=lambda item: abs(item[1]), reverse=True)
    top_k_edges_diff = sorted_edge_count_diff[:top_k_edges]

    plt.figure(figsize=(12, 6))
    sns.barplot(x=[str(edge[0]) for edge in top_k_edges_diff], y=[edge[1] for edge in top_k_edges_diff])
    plt.title(f'Top {top_k_edges} Edge Types with Largest Difference in Frequency (Vulnerable - Safe)')
    plt.xlabel('Edge Type')
    plt.ylabel('Frequency Difference')
    plt.xticks(rotation=45, ha='right')
    plt.show()

    # --- Feature Statistics (Separate Violin Plots for Each Class) ---
    # Per class: every feature value flattened into one list, with a parallel
    # list naming the node type each value belongs to.
    all_feature_stats = {0: [], 1: []}
    node_types = {0: [], 1: []}
    for data in dataset:
        class_label = data.y.item()
        for node_type, features in data.x_dict.items():
            all_feature_stats[class_label].extend(features.reshape(-1).tolist())
            node_types[class_label].extend([node_type] * features.numel())

    for class_label, class_name in enumerate(['Safe', 'Vulnerable']):
        plt.figure(figsize=(12, 6))
        sns.violinplot(x=node_types[class_label], y=all_feature_stats[class_label], inner='quartile')
        plt.title(f'Feature Value Distribution by Node Type ({class_name})')
        plt.xlabel('Node Type')
        plt.ylabel('Feature Value')
        plt.xticks(rotation=45)
        plt.show()


    # --- Outlier Detection (Box Plot, Separate by Class) ---
    # Both classes are drawn onto the same figure
    plt.figure(figsize=(12, 6))
    sns.boxplot(x=node_types[0], y=all_feature_stats[0], color='blue', label='Safe')
    sns.boxplot(x=node_types[1], y=all_feature_stats[1], color='red', label='Vulnerable')
    plt.title('Box Plot for Feature Outlier Detection by Class')
    plt.xlabel('Node Type')
    plt.ylabel('Feature Value')
    plt.xticks(rotation=45)
    plt.legend()
    plt.show()


    # Class Distribution 
    # Summing the labels counts the vulnerable graphs because labels are 0/1
    vulnerable_count = sum(data.y.item() for data in dataset)
    safe_count = len(dataset) - vulnerable_count

    plt.figure(figsize=(6, 6))
    plt.pie([vulnerable_count, safe_count], labels=['Vulnerable', 'Safe'], autopct='%1.1f%%', startangle=140)
    plt.title('Class Distribution')
    plt.show()

    # Example usage:
    # Assuming you have a list of `HeteroData` objects called `dataset`
visualize_graph_properties(dataset)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

def visualize_heterogeneous_graph(data, node_type_colors):
    """
    Visualizes a heterogeneous graph using networkx and matplotlib.

    Node indices restart at 0 for every node type, so each networkx node is
    keyed by ``(node_type, index)``. Keying by the bare index merged nodes of
    different types into one and attached edges to the wrong nodes.

    Args:
        data (HeteroData): The heterogeneous graph data.
        node_type_colors (dict): A dictionary mapping node types to colors.

    Raises:
        KeyError: If a node type of ``data`` has no entry in
            ``node_type_colors``.
    """

    G = nx.DiGraph()  # Create a directed graph

    # Add nodes, remembering the type of each one in its `label` attribute
    for node_type in data.node_types:
        for index in range(data[node_type].num_nodes):
            G.add_node((node_type, index), label=node_type)

    # Add edges between the type-qualified endpoints
    for edge_type in data.edge_types:
        src_type, _, dst_type = edge_type
        edge_index = data[edge_type].edge_index
        for src, dst in edge_index.t().tolist():
            G.add_edge((src_type, src), (dst_type, dst))

    # Increase figure size
    plt.figure(figsize=(150, 150))  # Adjust figsize as needed

    # Position nodes using a suitable layout with more space
    pos = nx.spring_layout(G, k=0.8, iterations=50)  # Adjust k and iterations

    # Draw nodes with colors
    for node_type in data.node_types:
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=[node for node, label in G.nodes(data='label') if label == node_type],
            node_color=node_type_colors[node_type],
            label=node_type,
        )

    # Draw edges
    nx.draw_networkx_edges(G, pos)

    # Label each node with its per-type index; the color tells the type
    node_labels = {node: node[1] for node in G.nodes}
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8)

    plt.title("Heterogeneous Graph Visualization")
    plt.axis('off')
    plt.show()


# Example usage:
# Assuming you have a `HeteroData` object called `data`

# Define colors for each node type
# (one entry per node label configured in object_types_and_properties)
node_type_colors = {
    'Domain': 'pink',
    'User': 'red',
    'Computer': 'blue',
    'Group': 'green',
    'OU': 'purple',
    'GPO': 'orange',
    'Container': 'yellow',
}

# Visualize the graph
# Draws the third graph of the training split.
visualize_heterogeneous_graph(train_dataset[2], node_type_colors)

## GNN 

### Model

#### SAGEConv

In [22]:
import torch
from torch_geometric.nn import SAGEConv, HeteroConv, GATConv

class HeterogeneousGNN(torch.nn.Module):
    """Single-layer heterogeneous GraphSAGE model that scores a whole graph.

    One ``SAGEConv`` is created per edge type in ``metadata[1]`` and combined
    with ``HeteroConv`` (mean aggregation). The resulting node embeddings are
    mean-pooled per node type, concatenated in the order of ``metadata[0]``
    and mapped by a linear layer plus sigmoid to a single value in (0, 1).

    Args:
        hidden_channels (int): Size of the node embeddings.
        metadata (tuple): ``(node_types, edge_types)`` of the graphs.
    """

    def __init__(self, hidden_channels, metadata):
        super().__init__()

        # Pooling order is fixed here so that every graph fills the same
        # slots of the classifier input, whatever order the conv output uses.
        self.node_types = list(metadata[0])
        self.hidden_channels = hidden_channels

        # Define convolutional layers for each edge type
        self.conv = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        }, aggr='mean')

        # Linear layer for final classification: one output, squashed by a sigmoid
        self.lin = torch.nn.Linear(hidden_channels * len(self.node_types), 1)

    def _pool(self, x_dict):
        """Mean-pool each node type; absent or empty types contribute zeros."""
        pooled = []
        for node_type in self.node_types:
            h = x_dict.get(node_type)
            if h is None or h.size(0) == 0:
                # A mean over zero rows would be NaN, and a missing type would
                # shrink the vector below what `self.lin` expects.
                pooled.append(self.lin.weight.new_zeros(self.hidden_channels))
            else:
                pooled.append(h.mean(dim=0))
        return torch.cat(pooled, dim=0)

    def forward(self, x_dict, edge_index_dict):
        """Return a tensor with one sigmoid score for the graph."""
        # 1. Heterogeneous Convolution
        x_dict = self.conv(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        # 2. Aggregation (mean of the node embeddings of each type)
        x = self._pool(x_dict)

        # 3. Linear layer for classification
        x = self.lin(x)

        # 4. Apply activation function
        return torch.sigmoid(x)


#### SAGEConv 2 layers?

In [48]:
class HeterogeneousGNN(torch.nn.Module):
    """Multi-layer heterogeneous GraphSAGE model that scores a whole graph.

    Stacks ``num_layers`` ``HeteroConv`` layers (one ``SAGEConv`` per edge type
    in ``metadata[1]``, mean aggregation), each followed by ReLU. The final
    node embeddings are mean-pooled per node type, concatenated in the order
    of ``metadata[0]`` and mapped by a linear layer plus sigmoid to a single
    value in (0, 1).

    Args:
        hidden_channels (int): Size of the node embeddings.
        metadata (tuple): ``(node_types, edge_types)`` of the graphs.
        num_layers (int): Number of stacked convolution layers.
    """

    def __init__(self, hidden_channels, metadata, num_layers=2):
        super().__init__()

        # Pooling order is fixed here so that every graph fills the same
        # slots of the classifier input, whatever order the conv output uses.
        self.node_types = list(metadata[0])
        self.hidden_channels = hidden_channels

        # Create multiple convolution layers
        self.convs = torch.nn.ModuleList([
            HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in metadata[1]
            }, aggr='mean') for _ in range(num_layers)
        ])

        # Linear layer for final classification: one output, squashed by a sigmoid
        self.lin = torch.nn.Linear(hidden_channels * len(self.node_types), 1)

    def _pool(self, x_dict):
        """Mean-pool each node type; absent or empty types contribute zeros."""
        pooled = []
        for node_type in self.node_types:
            h = x_dict.get(node_type)
            if h is None or h.size(0) == 0:
                # A mean over zero rows would be NaN, and a missing type would
                # shrink the vector below what `self.lin` expects.
                pooled.append(self.lin.weight.new_zeros(self.hidden_channels))
            else:
                pooled.append(h.mean(dim=0))
        return torch.cat(pooled, dim=0)

    def forward(self, x_dict, edge_index_dict):
        """Return a tensor with one sigmoid score for the graph."""
        # Multi-layer heterogeneous convolution
        # NOTE(review): a node type that a layer returns no embedding for is
        # not passed on to the next layer - confirm this is acceptable for
        # the edge types in use.
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        # Aggregation
        x = self._pool(x_dict)

        # Linear layer and final activation
        x = self.lin(x)
        return torch.sigmoid(x)

#### SAGEConv 2 layers

In [128]:
import torch
from torch_geometric.nn import SAGEConv, HeteroConv, GATConv
import torch.nn.functional as F

class HeterogeneousGNN(torch.nn.Module):
    """Two-layer heterogeneous GraphSAGE model that scores a whole graph.

    Two ``HeteroConv`` layers (one ``SAGEConv`` per edge type in
    ``metadata[1]``, sum aggregation) are each followed by ReLU and dropout
    (p=0.1). The embeddings of the node types in ``POOLED_NODE_TYPES`` are
    mean-pooled, concatenated and mapped by a linear layer plus sigmoid to a
    single value in (0, 1).

    Args:
        hidden_channels (int): Size of the node embeddings.
        metadata (tuple): ``(node_types, edge_types)`` of the graphs.
    """

    # Node types pooled into the graph representation, in concatenation order.
    POOLED_NODE_TYPES = ('User', 'Computer', 'Group', 'OU', 'GPO')

    def __init__(self, hidden_channels, metadata):
        super().__init__()

        self.hidden_channels = hidden_channels

        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        }, aggr='sum')

        self.conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        }, aggr='sum')

        # Sized from the pooled types so it always matches the concatenation
        # built in `forward`.
        self.lin = torch.nn.Linear(hidden_channels * len(self.POOLED_NODE_TYPES), 1)

    def _activate(self, x_dict):
        """Apply ReLU followed by dropout (p=0.1) to every node-type embedding."""
        return {
            key: F.dropout(F.relu(x), p=0.1, training=self.training)
            for key, x in x_dict.items()
        }

    def _pool(self, x_dict):
        """Mean-pool each pooled node type; absent or empty types contribute zeros."""
        pooled = []
        for node_type in self.POOLED_NODE_TYPES:
            h = x_dict.get(node_type)
            if h is None or h.size(0) == 0:
                # Avoids a KeyError for absent types and a NaN mean for empty ones.
                pooled.append(self.lin.weight.new_zeros(self.hidden_channels))
            else:
                pooled.append(h.mean(dim=0))
        return torch.cat(pooled, dim=0)

    def forward(self, x_dict, edge_index_dict):
        """Return a tensor with one sigmoid score for the graph."""
        # 1. First Convolution
        x_dict = self._activate(self.conv1(x_dict, edge_index_dict))

        # 2. Second Convolution: node types it returns nothing for keep their
        #    first-layer embedding.
        x_dict_conv2 = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: x_dict_conv2.get(key, x) for key, x in x_dict.items()}
        x_dict = self._activate(x_dict)

        # 3. Aggregation
        x = self._pool(x_dict)

        # 4. Linear and Activation
        x = self.lin(x)
        return torch.sigmoid(x)

#### GATConv

In [122]:
import torch
from torch_geometric.nn import SAGEConv, HeteroConv, GATConv

class HeterogeneousGNN(torch.nn.Module):
    """Single-layer heterogeneous GAT model that scores a whole graph.

    One ``GATConv`` (without self loops) is created per edge type in
    ``metadata[1]`` and combined with ``HeteroConv`` (sum aggregation). The
    embeddings of the node types in ``POOLED_NODE_TYPES`` are mean-pooled,
    concatenated and mapped by a linear layer plus sigmoid to a single value
    in (0, 1).

    Args:
        hidden_channels (int): Size of the node embeddings.
        metadata (tuple): ``(node_types, edge_types)`` of the graphs.
    """

    # Node types pooled into the graph representation, in concatenation order.
    POOLED_NODE_TYPES = ('User', 'Computer', 'Group', 'OU')

    def __init__(self, hidden_channels, metadata):
        super().__init__()

        self.hidden_channels = hidden_channels

        # Define convolutional layers for each edge type
        self.conv = HeteroConv({
            edge_type: GATConv((-1, -1), hidden_channels, add_self_loops=False)
            for edge_type in metadata[1]
        }, aggr='sum')

        # Linear layer for final classification: one output, squashed by a
        # sigmoid. Sized from the pooled types so it always matches the
        # concatenation built in `forward`.
        self.lin = torch.nn.Linear(hidden_channels * len(self.POOLED_NODE_TYPES), 1)

    def _pool(self, x_dict):
        """Mean-pool each pooled node type; absent or empty types contribute zeros."""
        pooled = []
        for node_type in self.POOLED_NODE_TYPES:
            h = x_dict.get(node_type)
            if h is None or h.size(0) == 0:
                # Avoids a KeyError for absent types and a NaN mean for empty ones.
                pooled.append(self.lin.weight.new_zeros(self.hidden_channels))
            else:
                pooled.append(h.mean(dim=0))
        return torch.cat(pooled, dim=0)

    def forward(self, x_dict, edge_index_dict):
        """Return a tensor with one sigmoid score for the graph."""
        # 1. Heterogeneous Convolution
        x_dict = self.conv(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        # 2. Aggregation (mean of the node embeddings of each pooled type)
        x = self._pool(x_dict)

        # 3. Linear layer for classification
        x = self.lin(x)

        # 4. Apply activation function
        return torch.sigmoid(x)


### Grid Search

In [None]:
import torch
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import f1_score
import numpy as np
import os

def plot_learning_curves(history, params, save_dir='hetero_learning_curves'):
    """Save a two-panel learning-curve figure for one parameter combination.

    The left panel shows ``history['train_losses']``; the right panel shows
    ``history['train_accuracies']`` and ``history['val_accuracies']``. The
    figure is written as a PNG into ``save_dir`` (created when missing) under
    a name built from ``params['hidden_channels']``,
    ``params['learning_rate']`` and ``params['num_layers']``, then closed.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: training loss per epoch
    loss_ax.plot(history['train_losses'], label='Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Loss')
    loss_ax.legend()

    # Right panel: training vs. validation accuracy per epoch
    acc_ax.plot(history['train_accuracies'], label='Training Accuracy')
    acc_ax.plot(history['val_accuracies'], label='Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_title('Training and Validation Accuracy')
    acc_ax.legend()

    # Figure-level heading naming the hyperparameters
    fig.suptitle(f'Hidden: {params["hidden_channels"]}, LR: {params["learning_rate"]}, Layers: {params["num_layers"]}')

    os.makedirs(save_dir, exist_ok=True)
    filename = f'hidden_{params["hidden_channels"]}_lr_{params["learning_rate"]}_numlayers_{params["num_layers"]}.png'
    target_path = os.path.join(save_dir, filename)
    fig.savefig(target_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

def train_model(model, train_dataset, val_dataset, learning_rate, num_epochs=50, warmup_epochs=10):
    """Train ``model`` one graph at a time and track validation metrics.

    Uses Adam with ``learning_rate`` and binary cross-entropy. After every
    epoch the model is evaluated on ``val_dataset`` (accuracy and binary F1
    at a 0.5 threshold). Epochs before ``warmup_epochs`` are recorded in the
    history but never selected as best; from then on the epoch with the
    highest validation F1 is kept.

    Returns:
        tuple: ``(best_metrics, history)`` where ``best_metrics`` holds
        ``train_loss``, ``train_acc``, ``val_acc``, ``val_f1``, ``epoch`` and
        ``is_warmup``, and ``history`` holds the per-epoch lists
        ``train_losses``, ``train_accuracies``, ``val_accuracies`` and
        ``val_f1_scores``.
    """
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    history = {
        series: []
        for series in ('train_losses', 'train_accuracies', 'val_accuracies', 'val_f1_scores')
    }
    best_metrics = dict(
        train_loss=float('inf'),
        train_acc=0,
        val_acc=0,
        val_f1=0,
        epoch=0,
        is_warmup=True,
    )

    for epoch in range(num_epochs):
        # ---- training pass: one optimizer step per graph ----
        model.train()
        running_loss = 0
        train_hits = 0
        train_seen = 0

        for idx in range(len(train_dataset)):
            sample = train_dataset[idx]
            optimizer.zero_grad()
            out = model(sample.x_dict, sample.edge_index_dict)
            target = sample.y.float()

            loss = criterion(out, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_hits += int(((out > 0.5).float() == target).sum())
            train_seen += 1

        train_loss = running_loss / len(train_dataset)
        train_acc = train_hits / train_seen

        # ---- validation pass ----
        model.eval()
        val_hits = 0
        val_seen = 0
        val_predictions, val_labels = [], []

        with torch.no_grad():
            for idx in range(len(val_dataset)):
                sample = val_dataset[idx]
                out = model(sample.x_dict, sample.edge_index_dict)
                target = sample.y.float()

                predicted = (out > 0.5).float()
                val_hits += int((predicted == target).sum())
                val_seen += 1

                val_predictions.extend(predicted.cpu().numpy())
                val_labels.extend(target.cpu().numpy())

        val_acc = val_hits / val_seen
        val_f1 = f1_score(val_labels, val_predictions, average='binary')

        # ---- bookkeeping ----
        for series, value in (
            ('train_losses', train_loss),
            ('train_accuracies', train_acc),
            ('val_accuracies', val_acc),
            ('val_f1_scores', val_f1),
        ):
            history[series].append(value)

        if epoch >= warmup_epochs:
            snapshot = {
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'val_f1': val_f1,
                'epoch': epoch,
            }
            if best_metrics['is_warmup']:
                # First eligible epoch always replaces the placeholder values
                best_metrics = {**snapshot, 'is_warmup': False}
            elif val_f1 > best_metrics['val_f1']:
                best_metrics.update(snapshot)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f"Val - Acc: {val_acc:.4f}, F1: {val_f1:.4f}")

    return best_metrics, history

def grid_search(train_dataset, val_dataset, metadata, warmup_epochs=10, param_grid=None):
    """Train one model per hyperparameter combination and keep the best one.

    Every combination of ``hidden_channels``, ``learning_rate`` and
    ``num_layers`` is trained with ``train_model``; its learning curves are
    saved with ``plot_learning_curves``. The winner is the combination with
    the highest validation accuracy (ties keep the earlier combination).

    Args:
        train_dataset: Training graphs, passed through to ``train_model``.
        val_dataset: Validation graphs, passed through to ``train_model``.
        metadata (tuple): ``(node_types, edge_types)`` handed to the model.
        warmup_epochs (int): Passed through to ``train_model``.
        param_grid (dict | None): Lists of candidate values under the keys
            ``'hidden_channels'``, ``'learning_rate'`` and ``'num_layers'``.
            Defaults to the built-in grid when ``None``.

    Returns:
        tuple: ``(results, best_params, best_state_dict)``. ``results`` is a
        list with one dict per combination (parameters plus the metrics of
        its best epoch). ``best_params`` is the winning entry and
        ``best_state_dict`` is that model's ``state_dict()`` taken after
        training finished. Both are ``None`` when the grid is empty.
    """
    if param_grid is None:
        param_grid = {
            'hidden_channels': [32, 64, 128],
            'learning_rate': [0.0001, 0.0005, 0.001],
            'num_layers': [1, 2, 3]
        }

    param_combinations = list(itertools.product(
        param_grid['hidden_channels'],
        param_grid['learning_rate'],
        param_grid['num_layers']
    ))

    results = []
    best_overall = {
        'train_loss': float('inf'),
        'train_acc': 0,
        'val_acc': 0,
        'val_f1': 0,
        'params': None,
        'model': None
    }

    total_combinations = len(param_combinations)
    for idx, (hidden_channels, lr, layers) in enumerate(param_combinations, 1):
        print(f"\nTesting combination {idx}/{total_combinations}:")
        print(f"Hidden Channels: {hidden_channels}, Learning Rate: {lr}, Layers: {layers}")

        model = HeterogeneousGNN(hidden_channels=hidden_channels, metadata=metadata, num_layers=layers)

        best_metrics, history = train_model(
            model,
            train_dataset,
            val_dataset,
            learning_rate=lr,
            warmup_epochs=warmup_epochs
        )

        current_params = {
            'hidden_channels': hidden_channels,
            'learning_rate': lr,
            'num_layers': layers,
            'train_loss': best_metrics['train_loss'],
            'train_acc': best_metrics['train_acc'],
            'val_acc': best_metrics['val_acc'],
            'val_f1': best_metrics['val_f1'],
            'best_epoch': best_metrics['epoch']
        }

        plot_learning_curves(history, current_params)
        results.append(current_params)

        # Selection metric is validation accuracy. The first combination is
        # always accepted: with a strict `>` against the initial 0, a run whose
        # validation accuracy is 0 would leave `params` as None and the
        # summary below would fail.
        is_first = best_overall['params'] is None
        if is_first or current_params['val_acc'] > best_overall['val_acc']:
            best_overall.update({
                'train_loss': current_params['train_loss'],
                'train_acc': current_params['train_acc'],
                'val_acc': current_params['val_acc'],
                'val_f1': current_params['val_f1'],
                'params': current_params,
                'model': model.state_dict()
            })

        # Print current best parameters
        print("\nCurrent Best Parameters:")
        print(f"Hidden Channels: {best_overall['params']['hidden_channels']}")
        print(f"Learning Rate: {best_overall['params']['learning_rate']}")
        print(f"Number of Layers: {best_overall['params']['num_layers']}")
        print(f"Validation Accuracy: {best_overall['val_acc']:.4f}")

    return results, best_overall['params'], best_overall['model']

# Run the grid search. `train_dataset`, `val_dataset` and `metadata` must
# already be defined. `best_model` receives the third return value of
# `grid_search` (the winning model's state_dict), not a model instance.
results, best_params, best_model = grid_search(
    train_dataset,
    val_dataset,
    metadata,
    warmup_epochs=10
)
# Summarise the winning combination: its hyperparameters followed by the
# metrics recorded at its best epoch.
print("\nBest Parameters:")
print(f"Hidden Channels: {best_params['hidden_channels']}")
print(f"Learning Rate: {best_params['learning_rate']}")
print(f"Number of Layers: {best_params['num_layers']}")
print(f"Training Loss: {best_params['train_loss']:.4f}")
print(f"Training Accuracy: {best_params['train_acc']:.4f}")
print(f"Validation Accuracy: {best_params['val_acc']:.4f}")
print(f"Validation F1 Score: {best_params['val_f1']:.4f}")
print(f"Best Epoch: {best_params['best_epoch']}")

### Training

In [None]:
import torch
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt

# Hyperparameters for the single training run below
hidden_channels = 32
learning_rate = 0.0005
num_epochs = 40
num_layers = 3

# Requires the `HeterogeneousGNN` variant whose constructor accepts `num_layers`
model = HeterogeneousGNN(hidden_channels=hidden_channels, metadata=metadata, num_layers=num_layers)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCELoss()


# Per-epoch values for plotting. Note that `test_accuracies` is filled with
# accuracies measured on `val_dataset`.
train_losses = []
train_accuracies = []
test_accuracies = []

for epoch in range(num_epochs):
    # Training: one optimizer step per graph
    model.train()
    epoch_loss = 0
    correct_train = 0
    total_train = 0


    for i in range(len(train_dataset)):
        data_object = train_dataset[i]
        optimizer.zero_grad()
        out = model(data_object.x_dict, data_object.edge_index_dict)
        y = data_object.y.float()

        # --- Debug statements ---
        #print(f"Epoch {epoch+1}, Data point {i+1}")
        #print(f"  Output shape: {out.shape}, Output values: {out}")
        #print(f"  Target shape: {y.shape}, Target values: {y}")
        # --- End debug statements ---

        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()  # Running sum of the per-graph losses

        # Calculate training accuracy (prediction is 1 when the output exceeds 0.5)
        pred_train = (out > 0.5).float()  
        correct_train += int((pred_train == y).sum())
        total_train += 1

        # --- Debug statements ---
        #print(f"  Loss: {loss.item():.4f}")
        #for name, param in model.named_parameters():
        #    if param.grad is not None:
        #        print(f"  Gradient norm of {name}: {param.grad.norm()}")
        # --- End debug statements ---

    train_accuracy = correct_train / total_train
    train_accuracies.append(train_accuracy)
    train_losses.append(epoch_loss / len(train_dataset))  # Mean loss per graph

    # Evaluation on the validation set (no gradient tracking)
    model.eval()
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for i in range(len(val_dataset)):
            data_object = val_dataset[i]
            out = model(data_object.x_dict, data_object.edge_index_dict)
            y = data_object.y.float()

            pred_val = (out > 0.5).float()
            correct_val += int((pred_val == y).sum())
            total_val += 1 # One graph per iteration

    val_accuracy = correct_val / total_val
    test_accuracies.append(val_accuracy)

    # The "Loss" printed here is the summed epoch loss, whereas `train_losses`
    # (plotted below) stores the per-graph mean.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, "
          f"Train Acc: {train_accuracy:.4f}, Validation Acc: {val_accuracy:.4f}")

# Plot the loss and accuracy curves
plt.figure(figsize=(12, 4))

# Left panel: mean training loss per epoch
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Right panel: training vs. validation accuracy per epoch
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(test_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.show()

In [None]:
# Measure accuracy of the trained `model` on the held-out test set.
model.eval()
correct_test = 0
total_test = 0
with torch.no_grad():
    for i in range(len(test_dataset)):  # Iterate over each data point in the test set
        data_object = test_dataset[i]
        out = model(data_object.x_dict, data_object.edge_index_dict)
        # NOTE(review): the training loop compares `out` with `y` without this
        # unsqueeze; here the comparison presumably relies on broadcasting
        # between the two shapes - confirm the shape of `data_object.y`.
        y = data_object.y.float().unsqueeze(1)  # Add a trailing dimension to the label

        pred_test = (out > 0.5).float()  # Prediction is 1 when the output exceeds 0.5
        correct_test += int((pred_test == y).sum())
        total_test += 1

# Fraction of test graphs predicted correctly
test_accuracy = correct_test / total_test
print(f"Test Accuracy: {test_accuracy:.4f}")

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

# Confusion matrix and classification metrics of the trained `model` on the
# validation set.
model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for i in range(len(val_dataset)):
        data_object = val_dataset[i]
        out = model(data_object.x_dict, data_object.edge_index_dict)
        y = data_object.y.float()

        pred_test = (out > 0.5).int()  # Store as integers for the confusion matrix
        # Move to CPU, convert to numpy, flatten, convert to list
        y_true.extend(y.cpu().numpy().flatten().tolist())
        y_pred.extend(pred_test.cpu().numpy().flatten().tolist())

# Convert lists to integer numpy arrays (labels are expected to be 0 or 1)
y_true = np.array(y_true).astype(int)
y_pred = np.array(y_pred).astype(int)

# Calculate confusion matrix. `labels` pins the matrix to 2x2 even when only
# one class occurs, so the unpacking below cannot fail.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
print("Confusion Matrix:")
print(cm)

# Calculate classification report (includes precision, recall, F1-score).
# `labels` keeps `target_names` aligned when a class is absent.
cr = classification_report(
    y_true,
    y_pred,
    labels=[0, 1],
    target_names=['Safe', 'Vulnerable'],
    zero_division=0,
)
print("\nClassification Report:")
print(cr)

# Calculate overall accuracy (can cross-check with the report)
accuracy = np.sum(y_true == y_pred) / len(y_true)
print(f"\nOverall Accuracy: {accuracy:.4f}")

# Extract TP, TN, FP, FN for manual metric calculation
TN, FP, FN, TP = cm.ravel()

# Manual calculation of metrics for verification; a metric whose denominator
# is zero is reported as 0.0 instead of NaN.
precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1: {f1:.4f}")

### Real life graphs

In [329]:
def test_single_graph(model, graph_data):
    """Classify a single graph with ``model`` and print the verdict.

    The model is switched to evaluation mode and run without gradient
    tracking on ``graph_data.x_dict`` / ``graph_data.edge_index_dict``. An
    output above 0.5 is reported as vulnerable, otherwise as safe.
    """
    model.eval()
    with torch.no_grad():
        score = model(graph_data.x_dict, graph_data.edge_index_dict)
        verdict = (score > 0.5).float()
        # `verdict` holds 1.0 when the score exceeded the threshold, else 0.0
        if verdict == 1:
            print(f"This network is vulnerable!")
        elif verdict == 0:
            print(f"This network is safe!")
            